class BaseMutator(nn.Module):
    """
    A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
    callbacks that are called in ``forward`` in mutables.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to apply mutator on.
    """

    def __init__(self, model):
        super().__init__()
        # Written straight into ``__dict__`` so that ``nn.Module.__setattr__`` does not register the
        # model as a submodule (see ``__setattr__`` below, which rejects ``self.model = ...``).
        self.__dict__["model"] = model
        self._structured_mutables = self._parse_search_space(self.model)

    def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
        """
        Recursively walk ``module`` and record every ``Mutable`` found in a tree.

        Parameters
        ----------
        module : nn.Module
            Module currently being visited.
        root : StructuredMutableTreeNode
            Tree node new mutables are attached to. A fresh root is created when omitted.
        prefix : str
            Dotted path of ``module`` relative to the model; becomes ``module.name`` for mutables.
        memo : set
            Modules already visited, shared across the whole recursion.
        nested_detection : Mutable
            The enclosing non-scope mutable, if any. Finding another mutable below it is an error.

        Returns
        -------
        StructuredMutableTreeNode
            The node mutables of this subtree were attached to.

        Raises
        ------
        RuntimeError
            If a mutable is nested in a non-scope mutable, or an ``InputChoice`` refers to a key that
            has not been visited yet and is not ``InputChoice.NO_KEY``.
        """
        if memo is None:
            memo = set()
        if root is None:
            root = StructuredMutableTreeNode(None)
        if module not in memo:
            memo.add(module)

            if isinstance(module, Mutable):
                if nested_detection is not None:
                    raise RuntimeError("Cannot have nested search space. Error at {} in {}"
                                       .format(module, nested_detection))
                module.name = prefix
                module.set_mutator(self)
                root = root.add_child(module)
                if not isinstance(module, MutableScope):
                    nested_detection = module
                if isinstance(module, InputChoice):
                    # ``memo`` does not change while the candidates are checked, so collect the keys
                    # seen so far once instead of rebuilding the list for every candidate.
                    known_keys = {m.key for m in memo if isinstance(m, Mutable)}
                    for k in module.choose_from:
                        if k != InputChoice.NO_KEY and k not in known_keys:
                            raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
                                               .format(k, module.key))
            for name, submodule in module._modules.items():
                if submodule is None:
                    continue
                submodule_prefix = prefix + ("." if prefix else "") + name
                self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
                                         nested_detection=nested_detection)
        return root

    @property
    def mutables(self):
        """
        The structure built by ``_parse_search_space`` over the model's mutables.

        NOTE(review): callers iterate this object directly; presumably iteration yields each mutable
        once per key, in definition order -- confirm against ``StructuredMutableTreeNode``.
        """
        return self._structured_mutables

    @property
    def undedup_mutables(self):
        """
        Traversal of the mutable tree with ``deduplicate=False``.
        """
        return self._structured_mutables.traverse(deduplicate=False)

    def forward(self, *inputs):
        """
        Warnings
        --------
        Don't call forward of a mutator.

        Raises
        ------
        RuntimeError
            Always.
        """
        raise RuntimeError("Forward is undefined for mutators.")

    def __setattr__(self, name, value):
        """
        Reject assignment to ``model``; every other attribute goes through ``nn.Module.__setattr__``.

        Raises
        ------
        AttributeError
            If ``name`` is ``"model"``.
        """
        if name == "model":
            raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
                                 "include you network, as it will include all parameters in model into the mutator.")
        return super().__setattr__(name, value)

    def enter_mutable_scope(self, mutable_scope):
        """
        Callback when forward of a MutableScope is entered. No-op by default.

        Parameters
        ----------
        mutable_scope : MutableScope
            The mutable scope that is entered.
        """
        pass

    def exit_mutable_scope(self, mutable_scope):
        """
        Callback when forward of a MutableScope is exited. No-op by default.

        Parameters
        ----------
        mutable_scope : MutableScope
            The mutable scope that is exited.
        """
        pass

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callbacks of forward in LayerChoice. Must be implemented by subclasses.

        Parameters
        ----------
        mutable : LayerChoice
            Module whose forward is called.
        args : list of torch.Tensor
            The arguments of its forward function.
        kwargs : dict
            The keyword arguments of its forward function.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output tensor and mask.
        """
        raise NotImplementedError

    def on_forward_input_choice(self, mutable, tensor_list):
        """
        Callbacks of forward in InputChoice. Must be implemented by subclasses.

        Parameters
        ----------
        mutable : InputChoice
            Mutable that is called.
        tensor_list : list of torch.Tensor
            The arguments mutable is called with.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output tensor and mask.
        """
        raise NotImplementedError

    def export(self):
        """
        Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
        network can be fully determined with these decisions for further training from scratch.
        Must be implemented by subclasses.

        Returns
        -------
        dict
            Mappings from mutable keys to decisions.
        """
        raise NotImplementedError
class Callback:
    """
    Base class for hooks that react to training events such as the begin/end of an epoch.

    Every hook is a no-op here; subclasses override the ones they need.
    """

    def __init__(self):
        # All four collaborators stay unset until ``build`` is called.
        self.model = self.optimizer = self.mutator = self.trainer = None

    def build(self, model, optimizer, mutator, trainer):
        """
        Attach the objects the callback reads from.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        optimizer : Optimizer
            Optimizer stored on the callback as ``self.optimizer``.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that is to call the callback.
        """
        self.model, self.optimizer = model, optimizer
        self.mutator, self.trainer = mutator, trainer

    def on_epoch_begin(self, epoch):
        """
        Hook run at the begin of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        return None

    def on_epoch_end(self, epoch):
        """
        Hook run at the end of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        return None

    def on_batch_begin(self, epoch):
        """
        Hook run at the begin of a batch.

        Parameters
        ----------
        epoch : int
            Value supplied by the caller of the hook.
        """
        return None

    def on_batch_end(self, epoch):
        """
        Hook run at the end of a batch.

        Parameters
        ----------
        epoch : int
            Value supplied by the caller of the hook.
        """
        return None
class BestArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` exactly once, when the final epoch ends.

    Parameters
    ----------
    checkpoint_path : str
        File path handed to ``trainer.export()``.
    epoches : int
        Total number of epochs; the export fires when ``epoch == epoches - 1``.
    """

    def __init__(self, checkpoint_path, epoches):
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.epoches = epoches

    def on_epoch_end(self, epoch):
        """
        Export the architecture to ``checkpoint_path`` if ``epoch`` is the last one.
        """
        is_final_epoch = epoch == self.epoches - 1
        if not is_final_epoch:
            return
        _logger.info("Saving architecture to %s", self.checkpoint_path)
        self.trainer.export(self.checkpoint_path)
class FixedArchitecture(Mutator):
    """
    Fixed architecture mutator that always selects a certain graph.

    Parameters
    ----------
    model : nn.Module
        A mutable network.
    fixed_arc : dict
        Preloaded architecture object.
    strict : bool
        Force everything that appears in ``fixed_arc`` to be used at least once.
        NOTE(review): this flag is accepted but not read anywhere in this class.
    """

    def __init__(self, model, fixed_arc, strict=True):
        super().__init__(model)
        self._fixed_arc = fixed_arc

        # Keys of the architecture must match the non-scope mutables of the model exactly.
        mutable_keys = {mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)}
        fixed_arc_keys = set(self._fixed_arc.keys())
        if fixed_arc_keys - mutable_keys:
            raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
        if mutable_keys - fixed_arc_keys:
            raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
        self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)

    def _from_human_readable_architecture(self, human_arc):
        """
        Convert an exported architecture into one boolean mask per mutable key.

        Parameters
        ----------
        human_arc : dict
            Mapping from mutable key to an entry carrying the decision under ``"_value"``.

        Returns
        -------
        dict
            Mapping from mutable key to a list; entries that were not already multi-hot are turned
            into a list of booleans with one element per candidate.
        """
        # There could be tensors, numpy arrays, etc.
        # NOTE(review): the next step indexes every converted entry with '_value', so ``to_list``
        # presumably leaves dict entries intact -- confirm against pytorch.utils.to_list.
        result_arc = {k: to_list(v) for k, v in human_arc.items()}
        # First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
        # which means {"op1": [0, ]} or {"op1": ["conv", ]}
        result_arc = {k: v['_value'] if isinstance(v['_value'], list) else [v['_value']] for k, v in result_arc.items()}
        # Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
        # This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
        # Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
        for mutable in self.mutables:
            if mutable.key not in result_arc:
                continue  # skip silently
            choice_arr = result_arc[mutable.key]

            if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
                if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
                        (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
                    # multihot, do nothing
                    continue
            if isinstance(mutable, LayerChoice):
                # Names become indices, then indices become a membership mask over all candidates.
                choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
                choice_arr = [i in choice_arr for i in range(len(mutable))]
            elif isinstance(mutable, InputChoice):
                choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
                choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
            result_arc[mutable.key] = choice_arr
        return result_arc

    def sample_search(self):
        """
        Always returns the fixed architecture.
        """
        return self._fixed_arc

    def sample_final(self):
        """
        Always returns the fixed architecture.
        """
        return self._fixed_arc

    def replace_layer_choice(self, module=None, prefix=""):
        """
        Replace layer choices with selected candidates. It's done with best effort.
        In case of weighted choices or multiple choices, candidates whose weight is an exact
        (non-float) zero are deleted. If single choice, replace the module with a normal module.

        Parameters
        ----------
        module : nn.Module
            Module to be processed. Defaults to the whole model.
        prefix : str
            Module name under global namespace.
        """
        if module is None:
            module = self.model
        for name, mutable in module.named_children():
            global_name = (prefix + "." if prefix else "") + name
            if isinstance(mutable, LayerChoice):
                chosen = self._fixed_arc[mutable.key]
                if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
                    # sum is one, max is one, there has to be an only one
                    # this is compatible with both integer arrays, boolean arrays and float arrays
                    _logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
                    setattr(module, name, mutable[chosen.index(1)])
                else:
                    if mutable.return_mask:
                        # The message has a %s placeholder; pass the module name so it is actually
                        # rendered instead of logging a literal "%s".
                        _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, "
                                     "LayerChoice will not be replaced.", global_name)
                    # remove unused parameters
                    for ch, n in zip(chosen, mutable.names):
                        if ch == 0 and not isinstance(ch, float):
                            setattr(mutable, n, None)
            else:
                self.replace_layer_choice(mutable, global_name)
class Net(nn.Module):
    """
    MNIST classifier whose first two convolutions and one skip connection are searchable.

    Parameters
    ----------
    hidden_size : int
        Width of the hidden fully connected layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Layers are created in the same order as before so seeded initialisation is unchanged.
        first_conv_candidates = OrderedDict([
            ("conv5x5", nn.Conv2d(1, 20, 5, 1)),
            ("conv3x3", nn.Conv2d(1, 20, 3, 1)),
        ])
        # two options of conv1
        self.conv1 = LayerChoice(first_conv_candidates, key='first_conv')
        mid_conv_candidates = [
            nn.Conv2d(20, 20, 3, 1, padding=1),
            nn.Conv2d(20, 20, 5, 1, padding=2),
        ]
        # two options of mid_conv
        self.mid_conv = LayerChoice(mid_conv_candidates, key='mid_conv')
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)
        # skip connection over mid_conv
        self.input_switch = InputChoice(n_candidates=2, n_chosen=1, key='skip')

    def forward(self, x):
        """
        Run the network on ``x`` and return ``log_softmax`` over dimension 1.
        """
        pooled = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        mid = F.relu(self.mid_conv(pooled))
        # The switch picks between an all-zero tensor and the pre-mid_conv activation.
        zeros = torch.zeros_like(pooled)
        skip = self.input_switch([zeros, pooled])
        merged = torch.add(mid, skip)
        merged = F.max_pool2d(F.relu(self.conv2(merged)), 2, 2)
        flat = merged.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)
def test(args, model, device, test_loader):
    """
    Evaluate ``model`` on ``test_loader`` and return the accuracy in percent.

    The summed NLL loss and the number of correct predictions are accumulated without gradients,
    then the per-sample loss and the accuracy are logged. ``args`` is not read.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for batch, labels in test_loader:
            batch = batch.to(device)
            labels = labels.to(device)
            log_probs = model(batch)
            # sum up batch loss
            loss_sum += F.nll_loss(log_probs, labels, reduction='sum').item()
            # get the index of the max log-probability
            predicted = log_probs.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    loss_sum /= len(test_loader.dataset)
    accuracy = 100. * n_correct / len(test_loader.dataset)

    summary = '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss_sum, n_correct, len(test_loader.dataset), accuracy)
    logger.info(summary)

    return accuracy
def get_params():
    """
    Parse the command-line options of the MNIST trial.

    Unknown options are ignored (``parse_known_args``), so extra flags do not abort the run.

    Returns
    -------
    argparse.Namespace
        Parsed training settings.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--data_dir", type=str,
                        default='./data', help="data directory")
    parser.add_argument("--selected_space_path", type=str,
                        default='./selected_space.json', help="selected_space_path")
    # The search space gets its own default file: with the previous default it shared
    # './selected_space.json' with --selected_space_path, so the selected architecture
    # overwrote the dumped search space on a default run.
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search_space_path")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="result_path")
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
                        help='hidden layer size (default: 512)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial_id,start from 0')

    args, _ = parser.parse_known_args()
    return args
class ClassicMutator(Mutator):
    """
    This mutator is to apply a chosen architecture on the model.
    It implements the sampling for LayerChoice and InputChoice, to only activate the chosen ones.

    On construction it dumps the search space to ``search_space_path``, obtains a chosen
    architecture (random, seeded by ``trial_id``, or loaded from ``selected_path``), writes it to
    ``selected_path`` and calls ``reset()``.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    trial_id : int
        Seed for the random architecture choice.
    selected_path : str
        JSON file the chosen architecture is written to (and read from when
        ``load_selected_space`` is true).
    search_space_path : str
        JSON file the generated search space is written to.
    load_selected_space : bool
        Load the chosen architecture from ``selected_path`` instead of sampling one.
    """

    def __init__(self, model, trial_id=0, selected_path="./selected_space.json",
                 search_space_path="./search_space.json", load_selected_space=False):
        # The path/trial arguments have defaults so that ``ClassicMutator(model)`` -- the call made by
        # ``get_and_apply_next_architecture`` -- is valid; explicit callers are unaffected.
        super(ClassicMutator, self).__init__(model)
        self._chosen_arch = {}
        self._search_space = self._generate_search_space()
        self.trial_id = trial_id
        self._dump_search_space(search_space_path)

        if load_selected_space:
            logger.warning("load selected space.")
            # Read the architecture from disk; previously an undefined attribute was assigned here.
            self._chosen_arch = self._load_selected_space(selected_path)
        else:
            self._chosen_arch = self.random_generate_chosen()

        self._generate_selected_space(selected_path)
        self.reset()

    def _sample_layer_choice(self, mutable, idx, value, search_space_item):
        """
        Convert layer choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            Number `idx` of list will be selected.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list for corresponding search space.

        Returns
        -------
        torch.Tensor
            Boolean one-hot tensor of length ``len(mutable)``.
        """
        # doesn't support multihot for layer choice yet
        onehot_list = [False] * len(mutable)
        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
            "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
        onehot_list[idx] = True
        return torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable

    def _sample_input_choice(self, mutable, idx, value, search_space_item):
        """
        Convert input choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : list of int
            Indices of the selected candidates.
        value : list of str
            The verbose representation of the selected candidates, aligned with ``idx``.
        search_space_item : dict
            The search space entry; its ``"candidates"`` list is checked against ``value``.

        Returns
        -------
        torch.Tensor
            Boolean multi-hot tensor of length ``mutable.n_candidates``.
        """
        candidate_repr = search_space_item["candidates"]
        multihot_list = [False] * mutable.n_candidates
        for i, v in zip(idx, value):
            assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
                "Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
            assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
            multihot_list[i] = True
        return torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable

    def sample_search(self):
        """
        See :meth:`sample_final`.
        """
        return self.sample_final()

    def sample_final(self):
        """
        Convert the chosen arch into one boolean tensor per LayerChoice / InputChoice key.

        Returns
        -------
        dict
            Mapping from mutable key to its one-hot / multi-hot tensor.
        """
        assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
            "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                       self._chosen_arch.keys())
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, (LayerChoice, InputChoice)):
                assert mutable.key in self._chosen_arch, \
                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
                data = self._chosen_arch[mutable.key]
                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
                    "'{}' is not a valid choice.".format(data)
            if isinstance(mutable, LayerChoice):
                result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
                                                                self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, InputChoice):
                result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
                                                                self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return result

    def _standalone_generate_chosen(self):
        """
        Generate the chosen architecture for standalone mode,
        i.e., choose the first one(s) for LayerChoice and InputChoice.
        ::
            { key_name: {"_value": "conv1",
                         "_idx": 0} }
            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }
        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                chosen_arch[key] = {"_value": choices[0], "_idx": 0}
            elif val["_type"] == INPUT_CHOICE:
                choices = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                if n_chosen is None:
                    n_chosen = len(choices)
                chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen_arch

    def random_generate_chosen(self):
        """
        Generate a random chosen architecture, seeded by ``self.trial_id``.

        One candidate is drawn for every LayerChoice and ``n_chosen`` distinct candidates (all of
        them when ``n_chosen`` is None) for every InputChoice. Note that this reseeds the global
        ``numpy.random`` and ``random`` generators.
        ::
            { key_name: {"_value": "conv1",
                         "_idx": 0} }
            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }
        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        np.random.seed(self.trial_id)
        random.seed(self.trial_id)
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                # int() keeps the index a plain Python int so it stays JSON-serialisable.
                chosen_idx = int(np.random.randint(len(choices)))
                chosen_arch[key] = {"_value": choices[chosen_idx], "_idx": chosen_idx}
            elif val["_type"] == INPUT_CHOICE:
                choices = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                if n_chosen is None:
                    n_chosen = len(choices)
                # Draw from ALL candidates. Sampling from range(n_chosen) could only ever return
                # the first n_chosen inputs, so later candidates were never selected.
                chosen_idx = random.sample(range(len(choices)), n_chosen)
                chosen_arch[key] = {"_value": [choices[idx] for idx in chosen_idx], "_idx": chosen_idx}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen_arch

    def _generate_search_space(self):
        """
        Generate search space from mutables.
        Here is the search space format:
        ::
            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }
            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }
        Returns
        -------
        dict
            the generated search space
        """
        search_space = {}
        for mutable in self.mutables:
            # for now we only generate flattened search space
            if isinstance(mutable, LayerChoice):
                key = mutable.key
                val = mutable.names
                search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
            elif isinstance(mutable, InputChoice):
                key = mutable.key
                search_space[key] = {"_type": INPUT_CHOICE,
                                     "_value": {"candidates": mutable.choose_from,
                                                "n_chosen": mutable.n_chosen}}
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))

        return search_space

    def _dump_search_space(self, file_path):
        """
        Write the generated search space to ``file_path`` as JSON.
        """
        with open(file_path, "w") as ss_file:
            json.dump(self._search_space, ss_file, sort_keys=True, indent=2)

    def _generate_selected_space(self, file_path):
        """
        Write the chosen architecture to ``file_path`` as JSON.
        """
        with open(file_path, "w") as ss_file:
            json.dump(self._chosen_arch, ss_file, sort_keys=True, indent=2)

    def _load_selected_space(self, file_path):
        """
        Read a chosen architecture previously written by ``_generate_selected_space``.
        """
        with open(file_path) as ss_file:
            return json.load(ss_file)
99.27390527256479}} +{'type': 'accuracy', 'result': {'sequence': 6, 'category': 'epoch', 'value': 99.13985701519213}} +{'type': 'accuracy', 'result': {'sequence': 7, 'category': 'epoch', 'value': 99.3632707774799}} +{'type': 'accuracy', 'result': {'sequence': 8, 'category': 'epoch', 'value': 99.4414655942806}} +{'type': 'accuracy', 'result': {'sequence': 9, 'category': 'epoch', 'value': 99.67605004468275}} +{'type': 'accuracy', 'result': {'sequence': 10, 'category': 'epoch', 'value': 99.74307417336908}} + diff --git a/dubhe-tadl/classic_nas/retrainer.py b/dubhe-tadl/classic_nas/retrainer.py new file mode 100644 index 0000000..0fd1289 --- /dev/null +++ b/dubhe-tadl/classic_nas/retrainer.py @@ -0,0 +1,288 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import copy +import logging +import os +import argparse +import logging +import sys +sys.path.append('..'+ '/' + '..') +from collections import OrderedDict + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchvision import datasets, transforms +from model import Net + +from pytorch.trainer import Trainer +from pytorch.utils import AverageMeterGroup +from pytorch.utils import mkdirs +from pytorch.mutables import LayerChoice, InputChoice +from fixed import apply_fixed_architecture + +from mutator import ClassicMutator +from abc import ABC, abstractmethod +from pytorch.retrainer import Retrainer +import numpy as np +import time +import json + +logger = logging.getLogger(__name__) +#logger.setLevel(logging.INFO) + + +class ClassicnasRetrainer(Retrainer): + """ + Classicnas trainer. + + Parameters + ---------- + model : nn.Module + PyTorch model to be trained. + loss : callable + Receives logits and ground truth label, return a loss tensor. + metrics : callable + Receives logits and ground truth label, return a dict of metrics. + optimizer : Optimizer + The optimizer used for optimizing the model. + epochs : int + Number of epochs planned for training. 
+ dataset_train : Dataset + Dataset for training. Will be split for training weights and architecture weights. + dataset_valid : Dataset + Dataset for testing. + mutator : ClassicMutator + Use in case of customizing your own ClassicMutator. By default will instantiate a ClassicMutator. + batch_size : int + Batch size. + workers : int + Workers for data loading. + device : torch.device + ``torch.device("cpu")`` or ``torch.device("cuda")``. + log_frequency : int + Step count per logging. + callbacks : list of Callback + list of callbacks to trigger at events. + arc_learning_rate : float + Learning rate of architecture parameters. + unrolled : float + ``True`` if using second order optimization, else first order optimization. + """ + def __init__(self, model, loss, metrics, + optimizer, epochs, dataset_train, dataset_valid, search_space_path,selected_space_path,checkpoint_dir,trial_id, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + callbacks=None, arc_learning_rate=3.0E-4, unrolled=False): + + + self.model = model + + self.loss = loss + self.metrics = metrics + self.optimizer = optimizer + self.epochs = epochs + self.device = device + self.batch_size = batch_size + self.checkpoint_dir =checkpoint_dir + + self.train_loader = torch.utils.data.DataLoader( + datasets.MNIST(dataset_train, train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=batch_size, shuffle=True, **kwargs) + + self.test_loader = torch.utils.data.DataLoader( + datasets.MNIST(dataset_valid, train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=1000, shuffle=True, **kwargs) + + + + self.search_space_path = search_space_path + self.selected_space_path =selected_space_path + self.trial_id = trial_id + self.result = {"accuracy": [],"cost_time": 0.} + + def train(self): + + + # t1 = time() + # phase 1. 
architecture step + #print(self.model.state_dict) + apply_fixed_architecture(self.model, self.selected_space_path) + #print(self.model.state_dict) + # phase 2: child network step + for child_epoch in range(1, self.epochs + 1): + + self.model.train() + for batch_idx, (data, target) in enumerate(self.train_loader): + data, target = data.to(self.device), target.to(self.device) + optimizer.zero_grad() + output = self.model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args['log_interval'] == 0: + logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + child_epoch, batch_idx * len(data), len(self.train_loader.dataset), + 100. * batch_idx / len(self.train_loader), loss.item())) + + test_acc = self.validate() + print({"type":"accuracy","result":{"sequence":child_epoch,"category":"epoch","value":test_acc}} ) + with open(args['result_path'], "a") as ss_file: + ss_file.write(json.dumps({"type":"accuracy","result":{"sequence":child_epoch,"category":"epoch","value":test_acc}} ) + '\n') + self.result['accuracy'].append(test_acc) + + + def validate(self): + self.model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in self.test_loader: + data, target = data.to(self.device), target.to(self.device) + output = self.model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, reduction='sum').item() + # get the index of the max log-probability + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(self.test_loader.dataset) + + accuracy = 100. * correct / len(self.test_loader.dataset) + + logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(self.test_loader.dataset), accuracy)) + + return accuracy + + # + # def export(self, file): + # """ + # Override the method to export to file. 
+ + # Parameters + # ---------- + # file : str + # File path to export to. + # """ + # raise NotImplementedError + + + def checkpoint(self): + """ + Override to dump a checkpoint. + """ + if isinstance(self.model, nn.DataParallel): + state_dict = self.model.module.state_dict() + else: + state_dict = self.model.state_dict() + if not os.path.exists(self.checkpoint_dir): + os.makedirs(self.checkpoint_dir) + dest_path = os.path.join(self.checkpoint_dir, f"best_checkpoint_epoch{self.epochs}.pth") + logger.info("Saving model to %s", dest_path) + torch.save(state_dict, dest_path) + + + +def dump_global_result(args,global_result): + with open(args['result_path'], "w") as ss_file: + json.dump(global_result, ss_file, sort_keys=True, indent=2) + +def get_params(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument("--data_dir", type=str, + default='./data', help="data directory") + parser.add_argument("--best_selected_space_path", type=str, + default='./selected_space.json', help="selected_space_path") + parser.add_argument("--search_space_path", type=str, + default='./selected_space.json', help="search_space_path") + parser.add_argument("--result_path", type=str, + default='./result.json', help="result_path") + parser.add_argument('--batch_size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument("--hidden_size", type=int, default=512, metavar='N', + help='hidden layer size (default: 512)') + parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument('--momentum', type=float, default=0.5, metavar='M', + help='SGD momentum (default: 0.5)') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + 
parser.add_argument('--no_cuda', default=False, + help='disables CUDA training') + parser.add_argument('--log_interval', type=int, default=1000, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument("--best_checkpoint_dir",type=str,default="path/to/", + help="Path for saved checkpoints. (default: %(default)s)") + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + + args, _ = parser.parse_known_args() + return args + +if __name__ == '__main__': + try: + start=time.time() + + params = vars(get_params()) + args =params + + use_cuda = not args['no_cuda'] and torch.cuda.is_available() + + torch.manual_seed(args['seed']) + + device = torch.device("cuda" if use_cuda else "cpu") + + kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + + data_dir = args['data_dir'] + + hidden_size = args['hidden_size'] + + model = Net(hidden_size=hidden_size).to(device) + + optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'], + momentum=args['momentum']) + + #mkdirs(args['search_space_path']) + mkdirs(args['best_selected_space_path']) + mkdirs(args['result_path']) + trainer = ClassicnasRetrainer(model, + loss=None, + metrics=None, + optimizer=optimizer, + epochs=args['epochs'], + dataset_train=data_dir, + dataset_valid=data_dir, + search_space_path = args['search_space_path'], + selected_space_path = args['best_selected_space_path'], + checkpoint_dir = args['best_checkpoint_dir'], + trial_id = args['trial_id'], + batch_size=args['batch_size'], + log_frequency=args['log_interval'], + device= device, + unrolled=None, + callbacks=None) + + with open(args['result_path'], "w") as ss_file: + ss_file.write('') + trainer.train() + trainer.checkpoint() + global_result = trainer.result + global_result['cost_time'] = str(time.time() - start) +'s' + #dump_global_result(params,global_result) + except Exception as exception: + logger.exception(exception) + raise \ No newline at end 
of file diff --git a/dubhe-tadl/classic_nas/selector.py b/dubhe-tadl/classic_nas/selector.py new file mode 100644 index 0000000..c69a1a6 --- /dev/null +++ b/dubhe-tadl/classic_nas/selector.py @@ -0,0 +1,58 @@ +import sys +sys.path.append('../..') +from pytorch.selector import Selector +from pytorch.utils import mkdirs +import shutil +import argparse +import os +import json + +class ClassicnasSelector(Selector): + def __init__(self, args, single_candidate=True): + super().__init__(single_candidate) + self.args = args + + def fit(self): + """ + only one candatite, function passed + """ + train_dir = os.path.join(self.args['experiment_dir'],'train') + max_accuracy = 0 + best_selected_space = '' + for trialId in os.listdir(train_dir): + path= os.path.join(train_dir,trialId,'result','result.json') + max_accuracy_trial = 0 + with open(path,'r') as f: + for line in f: + result_dict = json.loads(line) + accuracy = result_dict["result"]["value"] + if accuracy>max_accuracy_trial: + max_accuracy_trial=accuracy + print(max_accuracy_trial) + if max_accuracy_trial > max_accuracy: + max_accuracy = max_accuracy_trial + best_selected_space = os.path.join(train_dir,trialId,'model_selected_space','model_selected_space.json') + print('best trial id:',trialId) + + shutil.copyfile(best_selected_space,self.args['best_selected_space_path']) + + +def get_params(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument("--experiment_dir", type=str, + default='./experiment_dir', help="data directory") + parser.add_argument("--best_selected_space_path", type=str, + default='./best_selected_space.json', help="selected_space_path") + + args, _ = parser.parse_known_args() + return args + +if __name__ == "__main__": + + params = vars(get_params()) + args =params + mkdirs(args['best_selected_space_path']) + + hpo_selector = ClassicnasSelector(args,single_candidate=False) + hpo_selector.fit() \ No newline at end of file diff --git 
a/dubhe-tadl/classic_nas/trainer.py b/dubhe-tadl/classic_nas/trainer.py new file mode 100644 index 0000000..b945493 --- /dev/null +++ b/dubhe-tadl/classic_nas/trainer.py @@ -0,0 +1,269 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import copy +import logging +import os +import argparse +import logging +import sys +sys.path.append('..'+ '/' + '..') +from collections import OrderedDict + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchvision import datasets, transforms +from model import Net + +from pytorch.trainer import Trainer +from pytorch.utils import AverageMeterGroup +from pytorch.utils import mkdirs +from pytorch.mutables import LayerChoice, InputChoice + +from mutator import ClassicMutator +import numpy as np +import time +import json + +logger = logging.getLogger(__name__) +#logger.setLevel(logging.INFO) + + +class ClassicnasTrainer(Trainer): + """ + Classicnas trainer. + + Parameters + ---------- + model : nn.Module + PyTorch model to be trained. + loss : callable + Receives logits and ground truth label, return a loss tensor. + metrics : callable + Receives logits and ground truth label, return a dict of metrics. + optimizer : Optimizer + The optimizer used for optimizing the model. + num_epochs : int + Number of epochs planned for training. + dataset_train : Dataset + Dataset for training. Will be split for training weights and architecture weights. + dataset_valid : Dataset + Dataset for testing. + mutator : ClassicMutator + Use in case of customizing your own ClassicMutator. By default will instantiate a ClassicMutator. + batch_size : int + Batch size. + workers : int + Workers for data loading. + device : torch.device + ``torch.device("cpu")`` or ``torch.device("cuda")``. + log_frequency : int + Step count per logging. + callbacks : list of Callback + list of callbacks to trigger at events. + arc_learning_rate : float + Learning rate of architecture parameters. 
+ unrolled : float + ``True`` if using second order optimization, else first order optimization. + """ + def __init__(self, model, loss, metrics, + optimizer, epochs, dataset_train, dataset_valid, search_space_path,selected_space_path,trial_id, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + callbacks=None, arc_learning_rate=3.0E-4, unrolled=False): + + + self.model = model + + self.loss = loss + self.metrics = metrics + self.optimizer = optimizer + self.epochs = epochs + self.device = device + self.batch_size = batch_size + + self.train_loader = torch.utils.data.DataLoader( + datasets.MNIST(dataset_train, train=True, download=False, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=batch_size, shuffle=True, **kwargs) + + self.test_loader = torch.utils.data.DataLoader( + datasets.MNIST(dataset_valid, train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=1000, shuffle=True, **kwargs) + + + + self.search_space_path = search_space_path + self.selected_space_path =selected_space_path + self.trial_id = trial_id + self.num_epochs = 10 + self.classicmutator=ClassicMutator(self.model,trial_id=self.trial_id,selected_path=self.selected_space_path,search_space_path=self.search_space_path) + + self.result = {"accuracy": [],"cost_time": 0.} + + def train_one_epoch(self, epoch): + + + # t1 = time() + # phase 1. 
architecture step + self.classicmutator.trial_id = epoch + self.classicmutator._chosen_arch=self.classicmutator.random_generate_chosen() + #print('epoch:',epoch,'\n',self.classicmutator._chosen_arch) + + # phase 2: child network step + for child_epoch in range(1, self.epochs + 1): + + self.model.train() + for batch_idx, (data, target) in enumerate(self.train_loader): + data, target = data.to(self.device), target.to(self.device) + optimizer.zero_grad() + output = self.model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args['log_interval'] == 0: + logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + child_epoch, batch_idx * len(data), len(self.train_loader.dataset), + 100. * batch_idx / len(self.train_loader), loss.item())) + + test_acc = self.validate_one_epoch(epoch) + print({"type":"accuracy","result":{"sequence":child_epoch,"category":"epoch","value":test_acc}} ) + with open(args['result_path'], "a") as ss_file: + ss_file.write(json.dumps({"type":"accuracy","result":{"sequence":child_epoch,"category":"epoch","value":test_acc}} ) + '\n') + self.result['accuracy'].append(test_acc) + + def validate_one_epoch(self, epoch): + self.model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in self.test_loader: + data, target = data.to(self.device), target.to(self.device) + output = self.model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, reduction='sum').item() + # get the index of the max log-probability + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(self.test_loader.dataset) + + accuracy = 100. * correct / len(self.test_loader.dataset) + + logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(self.test_loader.dataset), accuracy)) + + return accuracy + def train(self): + """ + Train ``num_epochs``. 
+ Trigger callbacks at the start and the end of each epoch. + + Parameters + ---------- + validate : bool + If ``true``, will do validation every epoch. + """ + for epoch in range(self.num_epochs): + # training + self.train_one_epoch(epoch) + + +def dump_global_result(args,global_result): + with open(args['result_path'], "w") as ss_file: + json.dump(global_result, ss_file, sort_keys=True, indent=2) + +def get_params(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument("--data_dir", type=str, + default='./data', help="data directory") + parser.add_argument("--model_selected_space_path", type=str, + default='./selected_space.json', help="selected_space_path") + parser.add_argument("--search_space_path", type=str, + default='./selected_space.json', help="search_space_path") + parser.add_argument("--result_path", type=str, + default='./model_result.json', help="result_path") + parser.add_argument('--batch_size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument("--hidden_size", type=int, default=512, metavar='N', + help='hidden layer size (default: 512)') + parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument('--momentum', type=float, default=0.5, metavar='M', + help='SGD momentum (default: 0.5)') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--no_cuda', default=False, + help='disables CUDA training') + parser.add_argument('--log_interval', type=int, default=1000, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + + args, _ = 
parser.parse_known_args() + return args + +if __name__ == '__main__': + try: + start=time.time() + + params = vars(get_params()) + args =params + + use_cuda = not args['no_cuda'] and torch.cuda.is_available() + + torch.manual_seed(args['seed']) + + device = torch.device("cuda" if use_cuda else "cpu") + + kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + + data_dir = args['data_dir'] + + hidden_size = args['hidden_size'] + + model = Net(hidden_size=hidden_size).to(device) + + optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'], + momentum=args['momentum']) + + mkdirs(args['search_space_path']) + mkdirs(args['model_selected_space_path']) + mkdirs(args['result_path']) + trainer = ClassicnasTrainer(model, + loss=None, + metrics=None, + optimizer=optimizer, + epochs=args['epochs'], + dataset_train=data_dir, + dataset_valid=data_dir, + search_space_path = args['search_space_path'], + selected_space_path = args['model_selected_space_path'], + trial_id = args['trial_id'], + batch_size=args['batch_size'], + log_frequency=args['log_interval'], + device= device, + unrolled=None, + callbacks=None) + + with open(args['result_path'], "w") as ss_file: + ss_file.write('') + trainer.train_one_epoch(args['trial_id']) + #trainer.train() + global_result = trainer.result + #global_result['cost_time'] = str(time.time() - start) +'s' + #dump_global_result(params,global_result) + except Exception as exception: + logger.exception(exception) + raise diff --git a/dubhe-tadl/cream/README.md b/dubhe-tadl/cream/README.md new file mode 100644 index 0000000..0dd8bbd --- /dev/null +++ b/dubhe-tadl/cream/README.md @@ -0,0 +1,70 @@ +# Cream of the Crop: Distilling Prioritized Paths For One-Shot Neural Architecture Search + +## 0x01 requirements + +* Install the following requirements: + +``` +future +thop +timm<0.4 +yacs +ptflops==0.6.4 +#tensorboardx +#tensorboard +#opencv-python +#torch-scope +#git+https://github.com/sovrasov/flops-counter.pytorch.git 
+#git+https://github.com/Tramac/torchscope.git +``` + +* (required) Build and install apex to accelerate the training + (see [yuque](https://www.yuque.com/kcgyxv/ukpea3/mxz5xy)), + a little bit faster than pytorch DistributedDataParallel. + +* Put the imagenet data in `./data` Using the following script: + +``` +cd TADL_DIR/pytorch/cream/ +ln -s /mnt/data . +``` + +## 0x02 Quick Start + +* Run the following script to search an architecture. + +``` +python trainer.py +``` + +* Selector (deprecated) + +``` +python selector.py +``` + +* Train searched architectures. + +> Note: exponential moving average(model_ema) is not available yet. + +``` +python retrainer.py +``` + + diff --git a/dubhe-tadl/cream/algorithms/__init__.py b/dubhe-tadl/cream/algorithms/__init__.py new file mode 100644 index 0000000..95f42fe --- /dev/null +++ b/dubhe-tadl/cream/algorithms/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .trainer import CreamSupernetTrainer +from .mutator import RandomMutator diff --git a/dubhe-tadl/cream/algorithms/mutator.py b/dubhe-tadl/cream/algorithms/mutator.py new file mode 100644 index 0000000..4f24ac9 --- /dev/null +++ b/dubhe-tadl/cream/algorithms/mutator.py @@ -0,0 +1,36 @@ +import torch +import torch.nn.functional as F + +from pytorch.mutator import Mutator +from pytorch.mutables import LayerChoice, InputChoice + +# TODO: This class is duplicate with SPOS. +class RandomMutator(Mutator): + """ + Random mutator that samples a random candidate in the search space each time ``reset()``. + It uses random function in PyTorch, so users can set seed in PyTorch to ensure deterministic behavior. + """ + + def sample_search(self): + """ + Sample a random candidate. 
+ """ + result = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + gen_index = torch.randint(high=len(mutable), size=(1, )) + result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool() + elif isinstance(mutable, InputChoice): + if mutable.n_chosen is None: + result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool() + else: + perm = torch.randperm(mutable.n_candidates) + mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)] + result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable + return result + + def sample_final(self): + """ + Same as :meth:`sample_search`. + """ + return self.sample_search() diff --git a/dubhe-tadl/cream/algorithms/trainer.py b/dubhe-tadl/cream/algorithms/trainer.py new file mode 100644 index 0000000..d624fb8 --- /dev/null +++ b/dubhe-tadl/cream/algorithms/trainer.py @@ -0,0 +1,437 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import json +import numpy as np +import torch +import logging + +from copy import deepcopy +from pytorch.trainer import Trainer +from pytorch.utils import AverageMeterGroup + +from .utils import accuracy, reduce_metrics + +logger = logging.getLogger(__name__) + + +class CreamSupernetTrainer(Trainer): + """ + This trainer trains a supernet and output prioritized architectures that can be used for other tasks. + + Parameters + ---------- + model : nn.Module + Model with mutables. + loss : callable + Called with logits and targets. Returns a loss tensor. + val_loss : callable + Called with logits and targets for validation only. Returns a loss tensor. + optimizer : Optimizer + Optimizer that optimizes the model. + num_epochs : int + Number of epochs of training. + train_loader : iterablez + Data loader of training. Raise ``StopIteration`` when one epoch is exhausted. + valid_loader : iterablez + Data loader of validation. 
Raise ``StopIteration`` when one epoch is exhausted. + mutator : Mutator + A mutator object that has been initialized with the model. + batch_size : int + Batch size. + log_frequency : int + Number of mini-batches to log metrics. + meta_sta_epoch : int + start epoch of using meta matching network to pick teacher architecture + update_iter : int + interval of updating meta matching networks + slices : int + batch size of mini training data in the process of training meta matching network + pool_size : int + board size + pick_method : basestring + how to pick teacher network + choice_num : int + number of operations in supernet + sta_num : int + layer number of each stage in supernet (5 stage in supernet) + acc_gap : int + maximum accuracy improvement to omit the limitation of flops + flops_dict : Dict + dictionary of each layer's operations in supernet + flops_fixed : int + flops of fixed part in supernet + local_rank : int + index of current rank + callbacks : list of Callback + Callbacks to plug into the trainer. See Callbacks. 
+ """ + + def __init__(self, selected_space, model, loss, val_loss, + optimizer, num_epochs, train_loader, valid_loader, + mutator=None, batch_size=64, log_frequency=None, + meta_sta_epoch=20, update_iter=200, slices=2, + pool_size=10, pick_method='meta', choice_num=6, + sta_num=(4, 4, 4, 4, 4), acc_gap=5, + flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None, result_path=None): + assert torch.cuda.is_available() + super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None, + optimizer, num_epochs, None, None, + batch_size, None, None, log_frequency, callbacks) + self.selected_space = selected_space + self.model = model + self.loss = loss + self.val_loss = val_loss + self.train_loader = train_loader + self.valid_loader = valid_loader + self.log_frequency = log_frequency + self.batch_size = batch_size + self.optimizer = optimizer + self.model = model + self.loss = loss + self.num_epochs = num_epochs + self.meta_sta_epoch = meta_sta_epoch + self.update_iter = update_iter + self.slices = slices + self.pick_method = pick_method + self.pool_size = pool_size + self.local_rank = local_rank + self.choice_num = choice_num + self.sta_num = sta_num + self.acc_gap = acc_gap + self.flops_dict = flops_dict + self.flops_fixed = flops_fixed + + self.current_student_arch = None + self.current_teacher_arch = None + self.main_proc = (local_rank == 0) + self.current_epoch = 0 + + self.prioritized_board = [] + self.result_path = result_path + + # size of prioritized board + def _board_size(self): + return len(self.prioritized_board) + + # select teacher architecture according to the logit difference + def _select_teacher(self): + self._replace_mutator_cand(self.current_student_arch) + + if self.pick_method == 'top1': + meta_value, teacher_cand = 0.5, sorted( + self.prioritized_board, reverse=True)[0][3] + elif self.pick_method == 'meta': + meta_value, cand_idx, teacher_cand = -1000000000, -1, None + for now_idx, item in enumerate(self.prioritized_board): + inputx = 
item[4] + output = torch.nn.functional.softmax(self.model(inputx), dim=1) + weight = self.model.forward_meta(output - item[5]) + if weight > meta_value: + meta_value = weight + cand_idx = now_idx + teacher_cand = self.prioritized_board[cand_idx][3] + assert teacher_cand is not None + meta_value = torch.nn.functional.sigmoid(-weight) + else: + raise ValueError('Method Not supported') + + return meta_value, teacher_cand + + # check whether to update prioritized board + def _isUpdateBoard(self, prec1, flops): + if self.current_epoch <= self.meta_sta_epoch: + return False + + if len(self.prioritized_board) < self.pool_size: + return True + + if prec1 > self.prioritized_board[-1][1] + self.acc_gap: + return True + + if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]: + return True + + return False + + # update prioritized board + def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops): + if self._isUpdateBoard(prec1, flops): + val_prec1 = prec1 + training_data = deepcopy(inputs[:self.slices].detach()) + if len(self.prioritized_board) == 0: + features = deepcopy(outputs[:self.slices].detach()) + else: + features = deepcopy(teacher_output[:self.slices].detach()) + + self.prioritized_board.append( + (val_prec1, + prec1, + flops, + self.current_student_arch, + training_data, + torch.nn.functional.softmax( + features, + dim=1))) + self.prioritized_board = sorted( + self.prioritized_board, reverse=True) + + if len(self.prioritized_board) > self.pool_size: + self.prioritized_board = sorted( + self.prioritized_board, reverse=True) + del self.prioritized_board[-1] + + # only update student network weights + def _update_student_weights_only(self, grad_1): + for weight, grad_item in zip( + self.model.module.rand_parameters(self.current_student_arch), grad_1): + weight.grad = grad_item + torch.nn.utils.clip_grad_norm_( + self.model.module.rand_parameters(self.current_student_arch), 1) + self.optimizer.step() + for weight, 
# only update meta networks weights
def _update_meta_weights_only(self, teacher_cand, grad_teacher):
    """
    Apply one optimizer step using only the teacher/meta gradients.

    The gradients in ``grad_teacher`` are attached to the parameters returned
    by ``model.module.rand_parameters(teacher_cand, pick_method == 'meta')``,
    clipped to a total norm of 1, applied with ``optimizer.step()`` and then
    cleared again so they do not leak into the next backward pass.

    Parameters
    ----------
    teacher_cand :
        Teacher architecture whose parameters receive the gradients.
    grad_teacher : sequence of torch.Tensor
        One gradient per parameter, in ``rand_parameters`` order.
    """
    use_meta = self.pick_method == 'meta'
    # Materialise once: the same parameters are needed three times and
    # rand_parameters may hand back a one-shot iterator (not visible here).
    params = list(self.model.module.rand_parameters(teacher_cand, use_meta))

    for weight, grad_item in zip(params, grad_teacher):
        weight.grad = grad_item

    # Clip the parameters that actually carry the gradients set above
    # (previously the student architecture's parameters were clipped).
    torch.nn.utils.clip_grad_norm_(params, 1)

    self.optimizer.step()
    for weight, _ in zip(params, grad_teacher):
        weight.grad = None
# calculate soft target loss
def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
    """
    Cross entropy between predicted logits and a soft label distribution.

    Computes ``mean_i( sum_j( -soft_target[i, j] * log_softmax(pred)[i, j] ) )``
    with the log-softmax and the sum both taken over dimension 1.

    Parameters
    ----------
    pred : torch.Tensor
        Raw (un-normalised) scores; dimension 1 is the one normalised.
    soft_target : torch.Tensor
        Target distribution with the same shape as ``pred``.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    # Name the dimension explicitly: LogSoftmax() without ``dim`` relies on a
    # deprecated implicit choice and emits a warning on every call.
    log_prob = torch.nn.functional.log_softmax(pred, dim=1)
    return torch.mean(torch.sum(-soft_target * log_prob, 1))
def _get_cand_flops(self, cand):
    """
    Sum the FLOPs of the choices enabled in an architecture candidate.

    Parameters
    ----------
    cand : mapping
        Maps a layer-choice key to a sequence of truthy/falsy flags, one per
        candidate operation. Iteration order defines the block index used to
        look up ``flops_dict``.

    Returns
    -------
    number
        ``flops_fixed`` plus ``flops_dict[block_id][idx]`` for every enabled
        choice. The entries keyed ``'LayerChoice1'`` and ``'LayerChoice23'``
        are skipped.
    """
    skipped_blocks = ('LayerChoice1', 'LayerChoice23')
    flops = 0
    for block_id, block in enumerate(cand):
        # Compare the key, not the enumerate index: the previous test
        # ``block_id == 'LayerChoice23'`` could never be true.
        if block in skipped_blocks:
            continue
        for idx, choice in enumerate(cand[block]):
            flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
    return flops + self.flops_fixed
+ + # update metrics + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + metrics = {"prec1": prec1, "prec5": prec5, "loss": loss} + metrics = reduce_metrics(metrics) + meters.update(metrics) + + # update prioritized board + self._update_prioritized_board(input_data, + teacher_output, + output, + metrics['prec1'], + cand_flops) + + if self.main_proc and ( + step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch): + logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs, + step + 1, len(self.train_loader), meters) + + arch_list = [] + # if self.main_proc and self.num_epochs == epoch + 1: + for idx, i in enumerate(self.prioritized_board): + # logger.info("prioritized_board: No.%s %s", idx, i[:4]) + if idx == 0: + for arch in list(i[3].values()): + _ = arch.numpy() + _ = np.where(_)[0].tolist() + arch_list.append(_) + + if len(arch_list) > 0: + with open(self.selected_space, "w") as f: + print("dump selected space.") + json.dump({'selected_space': arch_list}, f) + + def validate_one_epoch(self, epoch): + self.model.eval() + meters = AverageMeterGroup() + with torch.no_grad(): + for step, (x, y) in enumerate(self.valid_loader): + self.mutator.reset() + logits = self.model(x) + loss = self.val_loss(logits, y) + prec1, prec5 = accuracy(logits, y, topk=(1, 5)) + metrics = {"prec1": prec1, "prec5": prec5, "loss": loss} + metrics = reduce_metrics(metrics) + meters.update(metrics) + + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.valid_loader), meters) + # print({'type': 'Accuracy', 'result': {'sequence': epoch, 'category': 'epoch', + # 'value': metrics["prec1"]}}) + if self.result_path is not None: + with open(self.result_path, "a") as ss_file: + ss_file.write(json.dumps( + {'type': 'Accuracy', + 'result': {'sequence': epoch, + 'category': 'epoch', + 'value': metrics["prec1"]}}) + '\n') diff --git 
def accuracy(output, target, topk=(1,)):
    """
    Compute precision@k for every ``k`` in ``topk``.

    ``output`` holds one row of scores per sample. ``target`` holds either
    class indices or, when it has more than one dimension, per-class scores
    whose arg-max is taken as the label. Returns a list with one scalar
    tensor per requested ``k``: the fraction of samples whose label is among
    the ``k`` highest scores.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # One-hot (or soft) targets are collapsed to class indices.
    if target.ndimension() > 1:
        target = target.max(1)[1]

    # Shape (largest_k, batch): row r holds each sample's (r + 1)-th best class.
    ranked = output.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    precisions = []
    for k in topk:
        hit_count = hits[:k].reshape(-1).float().sum(0)
        precisions.append(hit_count.mul_(1.0 / num_samples))
    return precisions
+# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from yacs.config import CfgNode as CN + +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +__C = CN() + +cfg = __C + +__C.AUTO_RESUME = True +__C.DATA_DIR = './data/imagenet' +__C.MODEL = 'cream' +__C.RESUME_PATH = './experiments/ckps/resume.pth.tar' +__C.SAVE_PATH = './experiments/ckps/' +__C.SEED = 42 +__C.LOG_INTERVAL = 50 +__C.RECOVERY_INTERVAL = 0 +__C.WORKERS = 4 +__C.NUM_GPU = 1 +__C.SAVE_IMAGES = False +__C.AMP = False +__C.ACC_GAP = 5 +__C.OUTPUT = 'output/path/' +__C.EVAL_METRICS = 'prec1' +__C.TTA = 0 # Test or inference time augmentation +__C.LOCAL_RANK = 0 +__C.VERBOSE = False + +# dataset configs +__C.DATASET = CN() +__C.DATASET.NUM_CLASSES = 1000 +__C.DATASET.IMAGE_SIZE = 224 # image patch size +__C.DATASET.INTERPOLATION = 'bilinear' # Image resize interpolation type +__C.DATASET.BATCH_SIZE = 32 # batch size +__C.DATASET.NO_PREFECHTER = False +__C.DATASET.PIN_MEM = True +__C.DATASET.VAL_BATCH_MUL = 4 + + +# model configs +__C.NET = CN() +__C.NET.SELECTION = 14 +__C.NET.GP = 'avg' # type of global pool ["avg", "max", "avgmax", "avgmaxc"] +__C.NET.DROPOUT_RATE = 0.0 # dropout rate +__C.NET.INPUT_ARCH = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] + +# model ema parameters +__C.NET.EMA = CN() +__C.NET.EMA.USE = True +__C.NET.EMA.FORCE_CPU = False # force model ema to be tracked on CPU +__C.NET.EMA.DECAY = 0.9998 + +# optimizer configs +__C.OPT = 'sgd' +__C.OPT_EPS = 1e-2 +__C.MOMENTUM = 0.9 +__C.WEIGHT_DECAY = 1e-4 +__C.OPTIMIZER = CN() +__C.OPTIMIZER.NAME = 'sgd' +__C.OPTIMIZER.MOMENTUM = 0.9 +__C.OPTIMIZER.WEIGHT_DECAY = 1e-3 + +# scheduler configs +__C.SCHED = 'sgd' +__C.LR_NOISE = None +__C.LR_NOISE_PCT = 0.67 
def _leading_version(version, size=2):
    """Return the first ``size`` numeric components of a dotted version string."""
    components = []
    for piece in str(version).split('.')[:size]:
        digits = ''
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        components.append(int(digits) if digits else 0)
    return tuple(components)


def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
        model_ema=None, logger=None, writer=None, local_rank=0):
    """
    Train ``model`` for one epoch and return the average loss.

    Every batch is moved to the GPU, forwarded, and used for one optimizer
    step. Running averages of loss, top-1/top-5 precision and timings are
    logged on rank 0 every ``args.log_interval`` batches and on the last
    batch.

    Parameters
    ----------
    epoch : int
        Index of the epoch being trained.
    model, loader, optimizer, loss_fn :
        Network, batch iterable yielding ``(input, target)``, optimizer and
        loss callable.
    args :
        Namespace providing ``num_gpu``, ``log_interval``, ``save_images``
        and ``recovery_interval``.
    lr_scheduler : optional
        If given, ``step_update`` is called after every batch.
    saver : optional
        Checkpoint saver used for recovery snapshots.
    output_dir : str
        Directory for the optional input image dumps.
    use_amp : bool
        Only forwarded to the legacy ``save_recovery`` signature.
    model_ema : optional
        Exponential-moving-average wrapper updated after every step.
    logger : optional
        Logger for progress messages; logging is skipped when ``None``.
    writer : optional
        Currently unused.
    local_rank : int
        Rank of this process; only rank 0 logs.

    Returns
    -------
    OrderedDict
        ``{'loss': <average training loss>}``.
    """
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    optimizer.zero_grad()
    for batch_idx, (images, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

        images = images.cuda()
        target = target.cuda()
        output = model(images)

        loss = loss_fn(output, target)

        prec1, prec5 = accuracy(output, target, topk=(1, 5))

        if args.num_gpu > 1:
            reduced_loss = reduce_tensor(loss.data, args.num_gpu)
            prec1 = reduce_tensor(prec1, args.num_gpu)
            prec5 = reduce_tensor(prec5, args.num_gpu)
        else:
            reduced_loss = loss.data

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize()

        losses_m.update(reduced_loss.item(), images.size(0))
        prec1_m.update(prec1.item(), output.size(0))
        prec5_m.update(prec5.item(), output.size(0))

        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            # NOTE(review): nesting reconstructed from a whitespace-collapsed
            # source; the image dump is assumed to be rank-0 only — confirm.
            if local_rank == 0:
                # ``logger`` defaults to None, so guard the call.
                if logger is not None:
                    logger.info(
                        'Train: {} [{:>4d}/{}] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx,
                            len(loader),
                            loss=losses_m,
                            top1=prec1_m,
                            top5=prec5_m,
                            batch_time=batch_time_m,
                            rate=images.size(0) * args.num_gpu / batch_time_m.val,
                            rate_avg=images.size(0) * args.num_gpu / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        images, os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0, normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            # Compare (major, minor) numerically. Reading the third character
            # of the version string breaks for e.g. "0.10.0" or "1.0.3".
            if _leading_version(timm.__version__) >= (0, 3):
                saver.save_recovery(
                    epoch,
                    batch_idx=batch_idx)
            else:
                saver.save_recovery(
                    model,
                    optimizer,
                    args,
                    epoch,
                    model_ema=model_ema,
                    use_amp=use_amp,
                    batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(
                num_updates=num_updates,
                metric=losses_m.avg)

        end = time.time()
    # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
+# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import time +import torch +import json +from collections import OrderedDict +from ..utils.util import AverageMeter, accuracy, reduce_tensor + + +def validate(epoch, model, loader, loss_fn, args, log_suffix='', + logger=None, writer=None, local_rank=0,result_path=None): + batch_time_m = AverageMeter() + losses_m = AverageMeter() + prec1_m = AverageMeter() + prec5_m = AverageMeter() + + model.eval() + + end = time.time() + last_idx = len(loader) - 1 + with torch.no_grad(): + for batch_idx, (input, target) in enumerate(loader): + last_batch = batch_idx == last_idx + + output = model(input) + if isinstance(output, (tuple, list)): + output = output[0] + + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold( + 0, + reduce_factor, + reduce_factor).mean( + dim=2) + target = target[0:target.size(0):reduce_factor] + + loss = loss_fn(output, target) + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + + if args.num_gpu > 1: + reduced_loss = reduce_tensor(loss.data, args.num_gpu) + prec1 = reduce_tensor(prec1, args.num_gpu) + prec5 = reduce_tensor(prec5, args.num_gpu) + else: + reduced_loss = loss.data + + torch.cuda.synchronize() + + losses_m.update(reduced_loss.item(), input.size(0)) + prec1_m.update(prec1.item(), output.size(0)) + prec5_m.update(prec5.item(), output.size(0)) + + batch_time_m.update(time.time() - end) + end = time.time() + if local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): + log_name = 'Test' + log_suffix + logger.info( + '{0}: [{1:>4d}/{2}] ' + 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' + 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( + log_name, batch_idx, last_idx, + batch_time=batch_time_m, loss=losses_m, + top1=prec1_m, top5=prec5_m)) + + # print({'type': 
'Accuracy', 'result': {'sequence': epoch, 'category': 'epoch', 'value': prec1_m.val}}) + + + if result_path is not None: + with open(result_path, "a") as ss_file: + ss_file.write(json.dumps( + {'type': 'Accuracy', + 'result': {'sequence': epoch, + 'category': 'epoch', + 'value': prec1_m.val}}) + '\n') + + + # writer.add_scalar( + # 'Loss' + log_suffix + '/vaild', + # prec1_m.avg, + # epoch * len(loader) + batch_idx) + # writer.add_scalar( + # 'Accuracy' + + # log_suffix + + # '/vaild', + # prec1_m.avg, + # epoch * + # len(loader) + + # batch_idx) + + metrics = OrderedDict( + [('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)]) + + return metrics diff --git a/dubhe-tadl/cream/lib/models/blocks/__init__.py b/dubhe-tadl/cream/lib/models/blocks/__init__.py new file mode 100644 index 0000000..2d0be66 --- /dev/null +++ b/dubhe-tadl/cream/lib/models/blocks/__init__.py @@ -0,0 +1,2 @@ +from .residual_block import get_Bottleneck, get_BasicBlock +from .inverted_residual_block import InvertedResidual \ No newline at end of file diff --git a/dubhe-tadl/cream/lib/models/blocks/inverted_residual_block.py b/dubhe-tadl/cream/lib/models/blocks/inverted_residual_block.py new file mode 100644 index 0000000..2f501b5 --- /dev/null +++ b/dubhe-tadl/cream/lib/models/blocks/inverted_residual_block.py @@ -0,0 +1,113 @@ +# This file is downloaded from +# https://github.com/rwightman/pytorch-image-models + +import torch.nn as nn + +from timm.models.layers import create_conv2d +from timm.models.efficientnet_blocks import make_divisible, resolve_se_args, \ + SqueezeExcite, drop_path + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE and CondConv routing""" + + def __init__( + self, + in_chs, + out_chs, + dw_kernel_size=3, + stride=1, + dilation=1, + pad_type='', + act_layer=nn.ReLU, + noskip=False, + exp_ratio=1.0, + exp_kernel_size=1, + pw_kernel_size=1, + se_ratio=0., + se_kwargs=None, + norm_layer=nn.BatchNorm2d, + norm_kwargs=None, + 
conv_kwargs=None, + drop_path_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_path_rate = drop_path_rate + + # Point-wise expansion + self.conv_pw = create_conv2d( + in_chs, + mid_chs, + exp_kernel_size, + padding=pad_type, + **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = create_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None + + # Point-wise linear projection + self.conv_pwl = create_conv2d( + mid_chs, + out_chs, + pw_kernel_size, + padding=pad_type, + **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PWL + info = dict( + module='conv_pwl', + hook_type='forward_pre', + num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + info = dict( + module='', + hook_type='', + num_chs=self.conv_pwl.out_channels) + return info + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x 
class Bottleneck(nn.Module):
    """
    Residual bottleneck: 1x1 reduce -> 3x3 (optionally strided) -> 1x1 expand.

    Parameters
    ----------
    inplanes : int
        Number of input channels.
    planes : int
        Requested number of output channels. The internal width is
        ``int(planes / expansion)``, so the block actually emits
        ``int(planes / expansion) * expansion`` channels.
    stride : int
        Stride of the 3x3 convolution (and of the shortcut projection).
    expansion : int
        Ratio between output width and internal width.
    """

    def __init__(self, inplanes, planes, stride=1, expansion=4):
        super(Bottleneck, self).__init__()
        planes = int(planes / expansion)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes,
            planes * expansion,
            kernel_size=1,
            bias=True)
        self.bn3 = nn.BatchNorm2d(planes * expansion)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride
        self.expansion = expansion
        # The shortcut must be projected whenever the main path changes the
        # tensor shape: a different channel count OR a spatial stride.
        # Checking channels alone left a strided block with equal channels
        # adding tensors of different spatial size.
        if stride != 1 or inplanes != planes * self.expansion:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * self.expansion,
                          kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(planes * self.expansion),
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
'post_exp', '') + self.verbose = verbose + self.in_chs = None + self.features = OrderedDict() + self.logger = logger + + def _round_channels(self, chs): + return round_channels( + chs, + self.channel_multiplier, + self.channel_divisor, + self.channel_min) + + def _make_block(self, ba, block_idx, block_count): + drop_path_rate = self.drop_path_rate * block_idx / block_count + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + self.logger.info( + ' InvertedResidual {}, Args: {}'.format( + block_idx, str(ba))) + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + self.logger.info( + ' DepthwiseSeparable {}, Args: {}'.format( + block_idx, str(ba))) + block = DepthwiseSeparableConv(**ba) + elif bt == 'cn': + if self.verbose: + self.logger.info( + ' ConvBnAct {}, Args: {}'.format( + block_idx, str(ba))) + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' 
% bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + + return block + + def __call__(self, in_chs, model_block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + model_block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + if self.verbose: + self.logger.info( + 'Building model trunk with %d stages...' % + len(model_block_args)) + self.in_chs = in_chs + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + feature_idx = 0 + stages = [] + # outer list of block_args defines the stacks ('stages' by some + # conventions) + for stage_idx, stage_block_args in enumerate(model_block_args): + last_stack = stage_idx == (len(model_block_args) - 1) + if self.verbose: + self.logger.info('Stack: {}'.format(stage_idx)) + assert isinstance(stage_block_args, list) + + blocks = [] + # each stack (stage) contains a list of block arguments + for block_idx, block_args in enumerate(stage_block_args): + last_block = block_idx == (len(stage_block_args) - 1) + extract_features = '' # No features extracted + if self.verbose: + self.logger.info(' Block: {}'.format(block_idx)) + + # Sort out stride, dilation, and feature extraction details + assert block_args['stride'] in (1, 2) + if block_idx >= 1: + # only the first block in any stack can have a stride > 1 + block_args['stride'] = 1 + + do_extract = False + if self.feature_location == 'pre_pwl': + if last_block: + next_stage_idx = stage_idx + 1 + if next_stage_idx >= len(model_block_args): + do_extract = True + else: + do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 + elif self.feature_location == 'post_exp': + if block_args['stride'] > 1 or (last_stack and last_block): + do_extract = True + if do_extract: + extract_features = self.feature_location 
class SuperNetBuilder:
    """Build the trunk of a supernet as a flat list of ``LayerChoice`` layers.

    For every block position the builder creates one candidate block per
    ``(kernel_size, exp_ratio)`` combination (plus optional extra candidates
    when ``dil_conv`` / ``resunit`` are enabled) and wraps the candidates in a
    ``LayerChoice``. Stages 0 and 6 get a single, fixed candidate unless
    ``dil_conv`` is set.

    Parameters
    ----------
    choices : dict
        Must provide the keys ``'kernel_size'`` and ``'exp_ratio'``, e.g.
        ``{'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}``.
    channel_multiplier, channel_divisor, channel_min
        Forwarded to ``round_channels`` when rounding channel counts.
    output_stride : int
        Once the accumulated stride would exceed this value, further strides
        are converted to dilation.
    pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs
        Defaults injected into every block's argument dict.
    drop_path_rate : float
        Maximum drop-path rate; scaled linearly with the global block index.
    feature_location : str
        One of ``'pre_pwl'``, ``'post_exp'`` or ``''`` (only validated and
        stored by this class).
    verbose : bool
        Emit build progress through ``logger``.
    resunit, dil_conv : bool
        Enable the additional bottleneck / dilated-conv candidates.
    logger : logging.Logger or None
        Destination of verbose messages; messages are dropped when ``None``.
    """

    def __init__(
            self,
            choices,
            channel_multiplier=1.0,
            channel_divisor=8,
            channel_min=None,
            output_stride=32,
            pad_type='',
            act_layer=None,
            se_kwargs=None,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            drop_path_rate=0.,
            feature_location='',
            verbose=False,
            resunit=False,
            dil_conv=False,
            logger=None):

        # Cartesian product of the searchable options, e.g.
        # choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}
        self.choices = [[x, y] for x in choices['kernel_size']
                        for y in choices['exp_ratio']]
        self.choices_num = len(self.choices) - 1
        # Number of candidates of the layer currently being built. ``__call__``
        # recomputes it for every layer; it is initialised here so that
        # ``_make_block`` does not fail with AttributeError when it is used
        # before ``__call__``.
        self.choice_num = len(self.choices)
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_path_rate = drop_path_rate
        self.feature_location = feature_location
        assert feature_location in ('pre_pwl', 'post_exp', '')
        self.verbose = verbose
        self.resunit = resunit
        self.dil_conv = dil_conv
        self.logger = logger

        # state updated during build, consumed by model
        self.in_chs = None

    def _log(self, message):
        """Emit ``message`` when verbose; silently skip if no logger was given."""
        if self.verbose and self.logger is not None:
            self.logger.info(message)

    def _round_channels(self, chs):
        """Round ``chs`` with the builder's multiplier/divisor/minimum settings."""
        return round_channels(
            chs,
            self.channel_multiplier,
            self.channel_divisor,
            self.channel_min)

    def _make_block(
            self,
            ba,
            choice_idx,
            block_idx,
            block_count,
            resunit=False,
            dil_conv=False):
        """Instantiate one candidate block from the argument dict ``ba``.

        ``ba`` is modified in place (``block_type`` is popped, defaults are
        filled in). ``self.in_chs`` is advanced to the block's output channels
        only when ``choice_idx`` is the last candidate index of the current
        layer, so that all candidates of a layer share the same input width.
        ``resunit`` and ``dil_conv`` are accepted for call compatibility and
        are not used here.

        Raises
        ------
        AssertionError
            If no activation layer is available or the block type is unknown.
        """
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if ba.get('fake_in_chs'):
            # FIXME this is a hack to work around mismatch in origin impl input
            # filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        if bt == 'ir':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            self._log(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
            block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            self._log(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'cn':
            self._log(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
            block = ConvBnAct(**ba)
        else:
            # Raised explicitly (instead of ``assert False``) so the check
            # survives ``python -O``; the exception type is unchanged.
            raise AssertionError(
                'Unknown block type (%s) while building model.' % bt)
        if choice_idx == self.choice_num - 1:
            self.in_chs = ba['out_chs']  # update in_chs for arg of next block

        return block

    def __call__(self, in_chs, model_block_args):
        """ Build the blocks
        Args:
            in_chs: Number of input-channels passed to first block
            model_block_args: A list of lists, outer list defines stages, inner
                list contains the decoded block argument dicts of that stage.
                The dicts are updated in place (stride/dilation).
        Return:
            Flat list with one LayerChoice per block position
        """
        self._log('Building model trunk with %d stages...' % len(model_block_args))
        self.in_chs = in_chs
        total_block_count = sum(len(x) for x in model_block_args)
        total_block_idx = 0
        current_stride = 2
        current_dilation = 1
        stages = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stage_idx, stage_block_args in enumerate(model_block_args):
            self._log('Stack: {}'.format(stage_idx))
            assert isinstance(stage_block_args, list)

            # each stack (stage) contains a list of block arguments
            for block_idx, block_args in enumerate(stage_block_args):
                self._log(' Block: {}'.format(block_idx))

                # Sort out stride and dilation details
                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:
                    # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        # stride budget exhausted: dilate instead of striding
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                    else:
                        current_stride = next_output_stride
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # first and last stage are fixed, all others are searchable
                if stage_idx == 0 or stage_idx == 6:
                    self.choice_num = 1
                else:
                    self.choice_num = len(self.choices)

                if self.dil_conv:
                    self.choice_num += 2

                choice_blocks = []
                block_args_copy = deepcopy(block_args)
                if self.choice_num == 1:
                    # create the block
                    block = self._make_block(
                        block_args, 0, total_block_idx, total_block_count)
                    choice_blocks.append(block)
                else:
                    # NOTE(review): for the fixed stages with dil_conv enabled
                    # choice_num is 3 while this loop still builds every entry
                    # of self.choices, so in_chs is advanced at index 2 -
                    # confirm that this is intended.
                    for choice_idx, choice in enumerate(self.choices):
                        # create the block
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(
                            block_args, choice[0], choice[1])
                        block = self._make_block(
                            block_args, choice_idx, total_block_idx, total_block_count)
                        choice_blocks.append(block)
                    if self.dil_conv:
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 3, 0)
                        block = self._make_block(
                            block_args, self.choice_num - 2, total_block_idx, total_block_count,
                            resunit=self.resunit, dil_conv=self.dil_conv)
                        choice_blocks.append(block)

                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 5, 0)
                        block = self._make_block(
                            block_args, self.choice_num - 1, total_block_idx, total_block_count,
                            resunit=self.resunit, dil_conv=self.dil_conv)
                        choice_blocks.append(block)

                    if self.resunit:
                        # extra bottleneck candidate sized after the last block
                        block = get_Bottleneck(block.conv_pw.in_channels,
                                               block.conv_pwl.out_channels,
                                               block.conv_dw.stride[0])
                        choice_blocks.append(block)

                choice_block = LayerChoice(choice_blocks)
                stages.append(choice_block)
                total_block_idx += 1  # incr global block idx (across all stacks)

        return stages
def gen_childnet(arch_list, arch_def, **kwargs):
    """Build a ``ChildNet`` from a searched architecture.

    Parameters
    ----------
    arch_list : list of list of int
        Per stage, the index of the chosen ``(kernel_size, exp_ratio)``
        combination for every block, e.g. ``[[0], [], [], [], [], [0]]``.
        Indices refer to the product of kernel sizes ``[3, 5, 7]`` and
        expansion ratios ``[4, 6]``.
    arch_def : list of list of str
        Per stage, the block definition strings of the supernet. Stages with a
        single block string are copied unchanged.
    **kwargs
        Forwarded to ``ChildNet`` (and to ``resolve_bn_args``).

    Returns
    -------
    ChildNet
        The instantiated child network.
    """
    # arch_list = [[0], [], [], [], [], [0]]
    choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}
    choices_list = [[x, y] for x in choices['kernel_size']
                    for y in choices['exp_ratio']]

    num_features = 1280

    # act_layer = HardSwish
    act_layer = Swish

    new_arch = []
    # change to child arch_def
    for layer_choice, layer_arch in zip(arch_list, arch_def):
        if len(layer_arch) == 1:
            new_arch.append(layer_arch)
            continue

        new_layer = []
        for block_choice, block_arch in zip(layer_choice, layer_arch):
            kernel_size, exp_ratio = choices_list[block_choice]
            elements = block_arch.split('_')
            # Rewrite the tokens by position. A plain ``str.replace`` would
            # also rewrite any other part of the string that happens to
            # contain the same characters.
            elements[2] = 'k{}'.format(str(kernel_size))
            elements[4] = 'e{}'.format(str(exp_ratio))
            new_layer.append('_'.join(elements))
        new_arch.append(new_layer)

    model_kwargs = dict(
        block_args=decode_arch_def(new_arch),
        num_features=num_features,
        stem_size=16,
        norm_kwargs=resolve_bn_args(kwargs),
        act_layer=act_layer,
        se_kwargs=dict(
            act_layer=nn.ReLU,
            gate_fn=hard_sigmoid,
            reduce_mid=True,
            divisor=8),
        **kwargs,
    )
    model = ChildNet(**model_kwargs)
    return model
class SuperNet(nn.Module):
    """Supernet made of a stem, a trunk of ``LayerChoice`` layers and a head.

    The trunk is produced by ``SuperNetBuilder``; besides the usual classifier
    the network owns ``meta_layer``, a linear layer mapping
    ``num_classes * slice`` inputs to a single output.
    """

    def __init__(
            self,
            block_args,
            choices,
            num_classes=1000,
            in_chans=3,
            stem_size=16,
            num_features=1280,
            head_bias=True,
            channel_multiplier=1.0,
            pad_type='',
            act_layer=nn.ReLU,
            drop_rate=0.,
            drop_path_rate=0.,
            slice=4,
            se_kwargs=None,
            norm_layer=nn.BatchNorm2d,
            logger=None,
            norm_kwargs=None,
            global_pool='avg',
            resunit=False,
            dil_conv=False,
            verbose=False):
        super(SuperNet, self).__init__()

        self.num_classes = num_classes
        self.num_features = num_features
        self.drop_rate = drop_rate
        self._in_chs = in_chans
        self.logger = logger

        # Stem
        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = create_conv2d(
            self._in_chs, stem_size, 3, stride=2, padding=pad_type)
        # norm_kwargs defaults to None, which cannot be unpacked with **.
        self.bn1 = norm_layer(stem_size, **(norm_kwargs or {}))
        self.act1 = act_layer(inplace=True)
        self._in_chs = stem_size

        # Middle stages (IR/ER/DS Blocks)
        builder = SuperNetBuilder(
            choices,
            channel_multiplier,
            8,
            None,
            32,
            pad_type,
            act_layer,
            se_kwargs,
            norm_layer,
            norm_kwargs,
            drop_path_rate,
            verbose=verbose,
            resunit=resunit,
            dil_conv=dil_conv,
            logger=self.logger)
        blocks = builder(self._in_chs, block_args)
        self.blocks = nn.Sequential(*blocks)
        self._in_chs = builder.in_chs

        # Head + Pooling
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = create_conv2d(
            self._in_chs,
            self.num_features,
            1,
            padding=pad_type,
            bias=head_bias)
        self.act2 = act_layer(inplace=True)

        # Classifier
        self.classifier = nn.Linear(
            self.num_features *
            self.global_pool.feat_mult(),
            self.num_classes)

        self.meta_layer = nn.Linear(self.num_classes * slice, 1)
        efficientnet_init_weights(self)

    def get_classifier(self):
        """Return the classification layer."""
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace pooling and classifier; the classifier becomes ``None``
        when ``num_classes`` is falsy."""
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        self.classifier = nn.Linear(
            self.num_features * self.global_pool.feat_mult(),
            num_classes) if self.num_classes else None

    def forward_features(self, x):
        """Run stem, trunk, global pooling and the 1x1 head convolution."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        # pooling happens before the head convolution
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        x = self.forward_features(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)

    def forward_meta(self, features):
        """Flatten ``features`` into a single row and apply ``meta_layer``."""
        return self.meta_layer(features.view(1, -1))

    def rand_parameters(self, architecture, meta=False):
        """Yield the parameters that belong to one sampled path.

        With ``meta=True`` only parameters whose name contains ``'meta'`` are
        yielded. Otherwise all parameters outside the trunk and outside the
        meta layer are yielded, followed by the parameters of the candidate
        selected by ``architecture`` at every position (``-1`` skips the
        position).
        """
        for name, param in self.named_parameters(recurse=True):
            if 'meta' in name and meta:
                yield param
            elif 'blocks' not in name and 'meta' not in name and (not meta):
                yield param

        if not meta:
            # NOTE(review): assumes self.blocks and architecture share the same
            # nested layout (stage -> position -> candidates) - confirm.
            for layer, layer_arch in zip(self.blocks, architecture):
                for blocks, arch in zip(layer, layer_arch):
                    if arch == -1:
                        continue
                    for name, param in blocks[arch].named_parameters(
                            recurse=True):
                        yield param


class Classifier(nn.Module):
    """Single linear layer mapping ``num_classes`` inputs to ``num_classes`` outputs."""

    def __init__(self, num_classes=1000):
        super(Classifier, self).__init__()
        self.classifier = nn.Linear(num_classes, num_classes)

    def forward(self, x):
        return self.classifier(x)
def parse_ksize(ss):
    """Parse a kernel-size token.

    A purely numeric string yields an ``int``; anything else is split on
    ``'.'`` and yields a list of ints (e.g. ``'3.5.7'`` -> ``[3, 5, 7]``).
    """
    return int(ss) if ss.isdigit() else [int(part) for part in ss.split('.')]


def decode_arch_def(
        arch_def,
        depth_multiplier=1.0,
        depth_trunc='ceil',
        experts_multiplier=1):
    """Decode a nested list of block strings into block argument dicts.

    Every stage (inner list) is decoded block by block with
    ``decode_block_str``; the decoded blocks and their repeat counts are then
    expanded by ``scale_stage_depth``. Returns one list of dicts per stage.
    """
    decoded_stages = []
    for block_strings in arch_def:
        assert isinstance(block_strings, list)
        stage_blocks = []
        stage_repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            block_cfg, repeat = decode_block_str(block_str)
            if block_cfg.get('num_experts', 0) > 0 and experts_multiplier > 1:
                block_cfg['num_experts'] *= experts_multiplier
            stage_blocks.append(block_cfg)
            stage_repeats.append(repeat)
        stage = scale_stage_depth(
            stage_blocks, stage_repeats, depth_multiplier, depth_trunc)
        decoded_stages.append(stage)
    return decoded_stages


def modify_block_args(block_args, kernel_size, exp_ratio):
    """Overwrite kernel size and expansion ratio of ``block_args`` in place.

    The kernel-size key depends on the block type (``'cn'`` ->
    ``kernel_size``, ``'er'`` -> ``exp_kernel_size``, everything else ->
    ``dw_kernel_size``); ``exp_ratio`` is only written for ``'ir'`` and
    ``'er'`` blocks. The same dict is returned.
    """
    kind = block_args['block_type']
    kernel_key = {
        'cn': 'kernel_size',
        'er': 'exp_kernel_size',
    }.get(kind, 'dw_kernel_size')
    block_args[kernel_key] = kernel_size

    if kind in ('ir', 'er'):
        block_args['exp_ratio'] = exp_ratio
    return block_args
+ Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they + # grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = nn.ReLU + elif v == 'r6': + value = nn.ReLU6 + elif v == 'sw': + value = Swish + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be + # used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = parse_ksize(options['p']) if 'p' in options else 1 + # FIXME hack to deal with in_chs issue in TPU def + fake_in_chs = int(options['fc']) if 'fc' in options else 0 + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + 
def scale_stage_depth(
        stack_args,
        repeats,
        depth_multiplier=1.0,
        depth_trunc='ceil'):
    """ Per-stage depth scaling

    Scales the total repeat count of one stage by ``depth_multiplier`` and
    distributes the scaled total over the stage's block definitions in
    proportion to their original repeats. Returns a flat list holding one deep
    copy of the block argument dict per resulting block.
    """
    # Total repeats of the stage; several block definitions may contribute.
    remaining = sum(repeats)
    target = remaining * depth_multiplier
    if depth_trunc == 'round':
        # Rounding keeps stages with few repeats small for longer.
        remaining_scaled = max(1, round(target))
    else:
        # EfficientNet default: any multiplier > 1.0 deepens every stage.
        remaining_scaled = int(math.ceil(target))

    # Walk the definitions back to front so the first block is the least
    # likely to be repeated; each one keeps at least a single repeat.
    scaled = []
    for count in reversed(repeats):
        share = max(1, round((count / remaining * remaining_scaled)))
        scaled.insert(0, share)
        remaining -= count
        remaining_scaled -= share

    # Expand every block definition into independent copies.
    return [
        deepcopy(block_cfg)
        for block_cfg, share in zip(stack_args, scaled)
        for _ in range(share)
    ]
class FlopsEst(object):
    """Lookup table of FLOPs and parameter counts of a supernet.

    At construction every candidate of every entry in ``model.blocks`` is
    measured once with ``get_model_complexity_info``; the stem, the global
    pooling and the head convolution are accumulated into ``flops_fixed`` /
    ``params_fixed``. All numbers are stored in millions.

    Parameters
    ----------
    model : nn.Module
        Must expose ``conv_stem``, ``blocks``, ``global_pool`` and
        ``conv_head``.
    input_shape : tuple
        Shape of the probe tensor; everything after the first entry is used
        as the per-sample input shape of the stem.
    device : str
        ``'cpu'`` keeps model and probe tensor on the CPU, any other value
        moves both to CUDA.
    """

    def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'):
        self.block_num = len(model.blocks)
        self.choice_num = len(model.blocks[0])
        self.flops_dict = {}
        self.params_dict = {}

        if device == 'cpu':
            model = model.cpu()
            tensor_device = 'cpu'
        else:
            model = model.cuda()
            tensor_device = 'cuda'

        self.params_fixed = 0
        self.flops_fixed = 0

        # The probe tensor has to live on the same device as the model.
        x = torch.randn(input_shape, device=tensor_device)

        # stem
        flops, params = get_model_complexity_info(
            model.conv_stem, tuple(input_shape[1:]),
            as_strings=False, print_per_layer_stat=False)
        self.params_fixed += params / 1e6
        self.flops_fixed += flops / 1e6

        x = model.conv_stem(x)

        for block_id, block in enumerate(model.blocks):
            self.flops_dict[block_id] = {}
            self.params_dict[block_id] = {}
            for module_id, module in enumerate(block):
                flops, params = get_model_complexity_info(
                    module, tuple(x.shape[1:]),
                    as_strings=False, print_per_layer_stat=False)
                # Flops(M)
                self.flops_dict[block_id][module_id] = flops / 1e6
                # Params(M)
                self.params_dict[block_id][module_id] = params / 1e6

            # Advance the probe through the last candidate to obtain the
            # input shape of the next block.
            # NOTE(review): placement after the candidate loop assumes all
            # candidates of a block accept the same input shape - confirm.
            x = module(x)

        # global pooling
        flops, params = get_model_complexity_info(
            model.global_pool, tuple(x.shape[1:]),
            as_strings=False, print_per_layer_stat=False)
        self.params_fixed += params / 1e6
        self.flops_fixed += flops / 1e6

        x = model.global_pool(x)

        # head convolution
        flops, params = get_model_complexity_info(
            model.conv_head, tuple(x.shape[1:]),
            as_strings=False, print_per_layer_stat=False)
        self.params_fixed += params / 1e6
        self.flops_fixed += flops / 1e6

    # return params (M)
    def get_params(self, arch):
        """Return the parameter count (M) of ``arch``.

        ``arch`` is a sequence with one candidate index per block; ``-1``
        marks a skipped block.
        """
        params = 0
        for block_id, block in enumerate(arch):
            if block == -1:
                continue
            params += self.params_dict[block_id][block]
        return params + self.params_fixed

    # return flops (M)
    def get_flops(self, arch):
        """Return the FLOPs (M) of ``arch``.

        ``arch`` maps a layer-choice name to a sequence of flags, one per
        candidate; the position of the name inside ``arch`` selects the row of
        the table. The entries ``'LayerChoice1'`` and ``'LayerChoice23'`` are
        skipped.
        """
        flops = 0
        for block_id, block in enumerate(arch):
            # The name has to be compared here; the integer position can
            # never equal a layer-choice name.
            if block in ('LayerChoice1', 'LayerChoice23'):
                continue
            for idx, choice in enumerate(arch[block]):
                flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
        return flops + self.flops_fixed
# Pre-computed FLOPs table used to search quickly for the number of layers.
# flops_op_dict[which_stage][which_operation] =
#     (flops_of_operation_with_stride1, flops_of_operation_with_stride2)

flops_op_dict = {
    0: {
        0: (21.828704, 18.820752),
        1: (32.669328, 28.16048),
        2: (25.039968, 23.637648),
        3: (37.486224, 35.385824),
        4: (29.856864, 30.862992),
        5: (44.711568, 46.22384),
    },
    1: {
        0: (11.808656, 11.86712),
        1: (17.68624, 17.780848),
        2: (13.01288, 13.87416),
        3: (19.492576, 20.791408),
        4: (14.819216, 16.88472),
        5: (22.20208, 25.307248),
    },
    2: {
        0: (8.198, 10.99632),
        1: (12.292848, 16.5172),
        2: (8.69976, 11.99984),
        3: (13.045488, 18.02248),
        4: (9.4524, 13.50512),
        5: (14.174448, 20.2804),
    },
    3: {
        0: (12.006112, 15.61632),
        1: (18.028752, 23.46096),
        2: (13.009632, 16.820544),
        3: (19.534032, 25.267296),
        4: (14.514912, 18.62688),
        5: (21.791952, 27.9768),
    },
    4: {
        0: (11.307456, 15.292416),
        1: (17.007072, 23.1504),
        2: (11.608512, 15.894528),
        3: (17.458656, 24.053568),
        4: (12.060096, 16.797696),
        5: (18.136032, 25.40832),
    },
}
def search_for_layer(flops_op_dict, arch_def, flops_minimum, flops_maximum):
    """Choose the number of layers per searchable stage for a FLOPs budget.

    Parameters
    ----------
    flops_op_dict : dict
        ``flops_op_dict[stage][operation]`` is a pair of FLOPs values; index 0
        of the pair is used for the first layer of a stage and index 1 for
        every additional layer.
    arch_def : list of list of str
        Block strings per stage; one fixed stage before and one after the
        five searchable stages.
    flops_minimum, flops_maximum : number
        Lower and upper bound of the FLOPs budget.

    Returns
    -------
    tuple
        ``(sta_num, arch_def, resolution)`` where ``sta_num`` holds the layer
        count of the five searchable stages, ``arch_def`` is truncated to
        those counts and ``resolution`` is ``size_factor * 32``;
        ``(None, None, None)`` when the budget cannot be met.
    """
    # one layer per searchable stage to start with
    sta_num = [1, 1, 1, 1, 1]
    # Stages are grown in this order; limits[k] is the maximum layer count
    # allowed while the pointer is at position k (two passes over the stages).
    order = [2, 3, 4, 1, 0, 2, 3, 4, 1, 0]
    limits = [3, 3, 3, 2, 2, 4, 4, 4, 4, 4]
    size_factor = 224 // 32
    # cost of one layer per stage with operation 0 / operation 5
    base_min_flops = sum([flops_op_dict[i][0][0] for i in range(5)])
    base_max_flops = sum([flops_op_dict[i][5][0] for i in range(5)])

    if base_min_flops > flops_maximum:
        # Even the smallest network exceeds the budget: lower the resolution
        # factor and rescale the budget instead of removing layers.
        while base_min_flops > flops_maximum and size_factor >= 2:
            size_factor = size_factor - 1
            flops_minimum = flops_minimum * (7. / size_factor)
            flops_maximum = flops_maximum * (7. / size_factor)
        if size_factor < 2:
            return None, None, None
    elif base_max_flops < flops_minimum:
        # Even the largest operations stay below the lower bound: add layers
        # until the lower bound is reached.
        cur_ptr = 0
        while base_max_flops < flops_minimum and cur_ptr <= 9:
            if sta_num[order[cur_ptr]] >= limits[cur_ptr]:
                cur_ptr += 1
                continue
            base_max_flops = base_max_flops + flops_op_dict[order[cur_ptr]][5][1]
            sta_num[order[cur_ptr]] += 1
        # NOTE(review): the nesting of this check was reconstructed from a
        # whitespace-collapsed source - confirm it belongs after the loop.
        if cur_ptr > 7 and base_max_flops < flops_minimum:
            return None, None, None

    # Greedily add further layers while the upper bound is respected.
    cur_ptr = 0
    while cur_ptr <= 9:
        if sta_num[order[cur_ptr]] >= limits[cur_ptr]:
            cur_ptr += 1
            continue
        base_max_flops = base_max_flops + flops_op_dict[order[cur_ptr]][5][1]
        if base_max_flops <= flops_maximum:
            sta_num[order[cur_ptr]] += 1
        else:
            break

    # keep only the selected number of block strings per stage; the first and
    # the last stage always keep exactly one
    arch_def = [item[:i] for i, item in zip([1] + sta_num + [1], arch_def)]
    # print(arch_def)

    return sta_num, arch_def, size_factor * 32
def get_path_acc(model, path, val_loader, args, val_iters=50):
    """Evaluate one sampled path of the supernet.

    At most ``val_iters`` batches of ``val_loader`` are evaluated without
    gradients. Batches are moved to CUDA unless ``args.prefetcher`` is set,
    and predictions are averaged over groups of ``args.tta`` samples when
    ``args.tta > 1``.

    Returns
    -------
    tuple
        Running averages ``(top1, top5)``.
    """
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()
    with torch.no_grad():
        for batch_idx, (images, target) in enumerate(val_loader):
            if batch_idx >= val_iters:
                break
            if not args.prefetcher:
                images = images.cuda()
                target = target.cuda()

            output = model(images, path)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(
                    0,
                    reduce_factor,
                    reduce_factor).mean(
                    dim=2)
                target = target[0:target.size(0):reduce_factor]

            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            torch.cuda.synchronize()

            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))

    return (prec1_m.avg, prec5_m.avg)


def get_logger(file_path):
    """ Make python logger

    Configures the root logger to write INFO messages to stdout and attaches
    a file handler for ``file_path``. Every call attaches another file
    handler.
    """
    log_format = '%(asctime)s | %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')
    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    return logger


def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()):
    """Split the trainable parameters of ``model`` into four groups.

    One-dimensional parameters, names ending in ``.bias`` and names listed in
    ``skip_list`` get no weight decay. Parameters whose name contains
    ``'meta_layer'`` use ``args.meta_lr`` (and no weight decay in either
    group); all others use ``args.lr``.

    Returns
    -------
    list of dict
        Parameter groups in the order: no-decay, decay, meta no-decay, meta
        decay.
    """
    decay = []
    no_decay = []
    meta_layer_no_decay = []
    meta_layer_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        is_meta = 'meta_layer' in name
        if len(param.shape) == 1 or name.endswith(
                ".bias") or name in skip_list:
            (meta_layer_no_decay if is_meta else no_decay).append(param)
        else:
            (meta_layer_decay if is_meta else decay).append(param)
    return [
        {'params': no_decay, 'weight_decay': 0., 'lr': args.lr},
        {'params': decay, 'weight_decay': weight_decay, 'lr': args.lr},
        {'params': meta_layer_no_decay, 'weight_decay': 0., 'lr': args.meta_lr},
        {'params': meta_layer_decay, 'weight_decay': 0, 'lr': args.meta_lr},
    ]


def create_optimizer_supernet(args, model, has_apex=False, filter_bias_and_bn=True):
    """Create the optimizer selected by ``args.opt``.

    Supported values are ``'sgd'`` / ``'nesterov'`` (SGD with Nesterov
    momentum), ``'momentum'`` (plain momentum SGD) and ``'adam'``. When weight
    decay is non-zero and ``filter_bias_and_bn`` is set, the parameters are
    split by ``add_weight_decay_supernet`` and weight decay is applied per
    group.

    Raises
    ------
    ValueError
        If ``args.opt`` names an unsupported optimizer.
    """
    weight_decay = args.weight_decay
    if args.opt in ('adamw', 'radam'):
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay_supernet(model, args, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if args.opt == 'fused':
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # ``lr`` is passed explicitly so the plain ``model.parameters()`` path
    # also honours args.lr; groups that carry their own ``lr`` keep it.
    if args.opt == 'sgd' or args.opt == 'nesterov':
        optimizer = optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=True)
    elif args.opt == 'momentum':
        optimizer = optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=False)
    elif args.opt == 'adam':
        optimizer = optim.Adam(
            parameters,
            lr=args.lr,
            weight_decay=weight_decay,
            eps=args.opt_eps)
    else:
        raise ValueError('Invalid optimizer: {!r}'.format(args.opt))

    return optimizer


def convert_lowercase(cfg):
    """Add a lower-case alias for every key of ``cfg`` in place.

    Existing keys are never overwritten; the same mapping is returned.
    """
    keys = cfg.keys()
    lowercase_keys = [key.lower() for key in keys]
    values = [cfg.get(key) for key in keys]
    for lowercase_key, value in zip(lowercase_keys, values):
        cfg.setdefault(lowercase_key, value)
    return cfg
#         help='configuration of cream')
#     parser.add_argument('--local_rank', type=int, default=0,
#                         help='local_rank')
#     args = parser.parse_args()
#
#     cfg.merge_from_file(args.cfg)
#     converted_cfg = convert_lowercase(cfg)
#
#     return args, converted_cfg


def get_model_flops_params(model, input_size=(1, 3, 224, 224)):
    """Profile a deep copy of ``model`` on a random input of ``input_size``.

    Returns
    -------
    tuple
        ``(macs, params)`` as formatted by ``clever_format(..., "%.3f")``.
    """
    dummy_input = torch.randn(input_size)
    # Profile a copy so the caller's model instance is left untouched.
    macs, params = profile(deepcopy(model), inputs=(dummy_input,), verbose=False)
    macs, params = clever_format([macs, params], "%.3f")
    return macs, params


def cross_entropy_loss_with_soft_target(pred, soft_target):
    """Cross entropy between logits ``pred`` and a soft target distribution.

    The log-softmax is taken over dimension 1, the same dimension the
    per-sample sum runs over, and the result is averaged over the remaining
    elements.
    """
    # An explicit ``dim`` avoids the deprecated implicit choice, which warns
    # and would pick dim 0 for 3-D input while the sum below uses dim 1.
    logsoftmax = nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))


def create_supernet_scheduler(optimizer, epochs, num_gpu, batch_size, lr):
    """Create the ``LambdaLR`` scheduler used for supernet training.

    Returns
    -------
    tuple
        ``(lr_scheduler, epochs)``.

    NOTE(review): ``LambdaLR`` multiplies each group's initial learning rate
    by the lambda's value, so the multiplier here starts at ``lr`` (not 1) and
    becomes negative once ``step / ITERS`` exceeds ``lr``. Left unchanged to
    preserve existing training behaviour -- confirm this is intended.
    """
    # 1280000 is presumably the approximate ImageNet train-set size, making
    # ITERS the total number of iterations -- TODO confirm.
    ITERS = epochs * \
        (1280000 / (num_gpu * batch_size))
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (
        lr - step / ITERS) if step <= ITERS else 0, last_epoch=-1)
    return lr_scheduler, epochs


# diff --git a/dubhe-tadl/cream/pretrained/README.md b/dubhe-tadl/cream/pretrained/README.md
# new file mode 100644
# index 0000000..bb0f24b
# --- /dev/null
# +++ b/dubhe-tadl/cream/pretrained/README.md
# @@ -0,0 +1,6 @@
# +## Pretrained models
# +
# +The official 14M/43M/114M/287M/481M/604M pretrained models in
# +[google drive](https://drive.google.com/drive/folders/1CQjyBryZ4F20Rutj7coF8HWFcedApUn2) or
# +[Models-Baidu Disk (password: wqw6)](https://pan.baidu.com/s/1TqQNm2s14oEdyNPimw3T9g).
# +
# diff --git a/dubhe-tadl/cream/retrainer.py b/dubhe-tadl/cream/retrainer.py
# new file mode 100644
# index 0000000..1772869
# --- /dev/null
# +++ b/dubhe-tadl/cream/retrainer.py
# @@ -0,0 +1,462 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com +import sys + +sys.path.append('../..') +import os +import json +import time +import timm +import torch +import numpy as np +import torch.nn as nn + +from argparse import ArgumentParser +# from torch.utils.tensorboard import SummaryWriter + +# import timm packages +from timm.optim import create_optimizer +from timm.models import resume_checkpoint +from timm.scheduler import create_scheduler +from timm.data import Dataset, create_loader +from timm.utils import CheckpointSaver, ModelEma, update_summary +from timm.loss import LabelSmoothingCrossEntropy + +# import apex as distributed package +try: + from apex import amp + from apex.parallel import DistributedDataParallel as DDP + from apex.parallel import convert_syncbn_model + + HAS_APEX = True +except ImportError as e: + print(e) + from torch.nn.parallel import DistributedDataParallel as DDP + + HAS_APEX = False + +# import models and training functions +from pytorch.utils import mkdirs, save_best_checkpoint, str2bool +from lib.core.test import validate +from lib.core.retrain import train_epoch +from lib.models.structures.childnet import gen_childnet +from lib.utils.util import get_logger, get_model_flops_params +from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def parse_args(): + """See lib.utils.config""" + parser = ArgumentParser() + + # path + parser.add_argument("--best_checkpoint_dir", type=str, default='./output/best_checkpoint/') + parser.add_argument("--checkpoint_dir", type=str, default='./output/checkpoints/') + parser.add_argument("--data_dir", type=str, default='./data') + parser.add_argument("--experiment_dir", type=str, default='./') + parser.add_argument("--model_name", type=str, default='retrainer') + parser.add_argument("--log_path", type=str, default='output/log') + parser.add_argument("--result_path", type=str, default='output/result.json') + 
parser.add_argument("--best_selected_space_path", type=str, + default='output/selected_space.json') + + # int + parser.add_argument("--acc_gap", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--cooldown_epochs", type=int, default=10) + parser.add_argument("--decay_epochs", type=int, default=10) + parser.add_argument("--epochs", type=int, default=200) + parser.add_argument("--flops_minimum", type=int, default=0) + parser.add_argument("--flops_maximum", type=int, default=200) + parser.add_argument("--image_size", type=int, default=224) + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument("--log_interval", type=int, default=50) + parser.add_argument("--meta_sta_epoch", type=int, default=20) + parser.add_argument("--num_classes", type=int, default=1000) + parser.add_argument("--num_gpu", type=int, default=1) + parser.add_argument("--parience_epochs", type=int, default=10) + parser.add_argument("--pool_size", type=int, default=10) + parser.add_argument("--recovery_interval", type=int, default=10) + parser.add_argument("--trial_id", type=int, default=42) + parser.add_argument("--selection", type=int, default=-1) + parser.add_argument("--slice_num", type=int, default=4) + parser.add_argument("--tta", type=int, default=0) + parser.add_argument("--update_iter", type=int, default=1300) + parser.add_argument("--val_batch_mul", type=int, default=4) + parser.add_argument("--warmup_epochs", type=int, default=3) + parser.add_argument("--workers", type=int, default=4) + + # float + parser.add_argument("--color_jitter", type=float, default=0.4) + parser.add_argument("--decay_rate", type=float, default=0.1) + parser.add_argument("--dropout_rate", type=float, default=0.0) + parser.add_argument("--ema_decay", type=float, default=0.998) + parser.add_argument("--lr", type=float, default=1e-2) + parser.add_argument("--meta_lr", type=float, default=1e-4) + parser.add_argument("--re_prob", type=float, 
default=0.2) + parser.add_argument("--opt_eps", type=float, default=1e-2) + parser.add_argument("--momentum", type=float, default=0.9) + parser.add_argument("--min_lr", type=float, default=1e-5) + parser.add_argument("--smoothing", type=float, default=0.1) + parser.add_argument("--weight_decay", type=float, default=1e-4) + parser.add_argument("--warmup_lr", type=float, default=1e-4) + + # bool + parser.add_argument("--auto_resume", type=str2bool, default='False') + parser.add_argument("--dil_conv", type=str2bool, default='False') + parser.add_argument("--ema_cpu", type=str2bool, default='False') + parser.add_argument("--pin_mem", type=str2bool, default='True') + parser.add_argument("--resunit", type=str2bool, default='False') + parser.add_argument("--save_images", type=str2bool, default='False') + parser.add_argument("--sync_bn", type=str2bool, default='False') + parser.add_argument("--use_ema", type=str2bool, default='False') + parser.add_argument("--verbose", type=str2bool, default='False') + + # str + parser.add_argument("--aa", type=str, default='rand-m9-mstd0.5') + parser.add_argument("--eval_metrics", type=str, default='prec1') + # gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"] + parser.add_argument("--gp", type=str, default='avg') + parser.add_argument("--interpolation", type=str, default='bilinear') + parser.add_argument("--opt", type=str, default='sgd') + parser.add_argument("--pick_method", type=str, default='meta') + parser.add_argument("--re_mode", type=str, default='pixel') + parser.add_argument("--sched", type=str, default='sgd') + + args = parser.parse_args() + args.sync_bn = False + args.verbose = False + args.data_dir = args.data_dir + "/imagenet" + return args + + +def main(): + args = parse_args() + + mkdirs(args.checkpoint_dir + "/", + args.experiment_dir, + args.best_selected_space_path, + args.result_path) + with open(args.result_path, "w") as ss_file: + ss_file.write('') + + if len(args.checkpoint_dir > 1): + 
mkdirs(args.best_checkpoint_dir + "/") + + args.checkpoint_dir = os.path.join( + args.checkpoint_dir, + "{}_{}".format(args.model_name, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + ) + if not os.path.exists(args.checkpoint_dir): + os.mkdir(args.checkpoint_dir) + + # resolve logging + if args.local_rank == 0: + logger = get_logger(args.log_path) + writer = None # SummaryWriter(os.path.join(output_dir, 'runs')) + else: + writer, logger = None, None + + # retrain model selection + + if args.selection == -1: + if os.path.exists(args.best_selected_space_path): + with open(args.best_selected_space_path, "r") as f: + arch_list = json.load(f)['selected_space'] + else: + args.selection = 14 + logger.warning("args.best_selected_space_path is not exist. Set selection to 14.") + + if args.selection == 481: + arch_list = [ + [0], [ + 3, 4, 3, 1], [ + 3, 2, 3, 0], [ + 3, 3, 3, 1], [ + 3, 3, 3, 3], [ + 3, 3, 3, 3], [0]] + args.image_size = 224 + elif args.selection == 43: + arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]] + args.image_size = 96 + elif args.selection == 14: + arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] + args.image_size = 64 + elif args.selection == 112: + arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]] + args.image_size = 160 + elif args.selection == 287: + arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] + args.image_size = 224 + elif args.selection == 604: + arch_list = [ + [0], [ + 3, 3, 2, 3, 3], [ + 3, 2, 3, 2, 3], [ + 3, 2, 3, 2, 3], [ + 3, 3, 2, 2, 3, 3], [ + 3, 3, 2, 3, 3, 3], [0]] + args.image_size = 224 + elif args.selection == -1: + args.image_size = 224 + else: + raise ValueError("Model Retrain Selection is not Supported!") + + print(arch_list) + # define childnet architecture from arch_list + stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25'] + + # TODO: this param from NNI is different from microsoft/Cream. 
+ choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25', + 'ir_r1_k5_s2_e4_c40_se0.25', + 'ir_r1_k3_s2_e6_c80_se0.25', + 'ir_r1_k3_s1_e6_c96_se0.25', + 'ir_r1_k5_s2_e6_c192_se0.25'] + arch_def = [[stem[0]]] + [[choice_block_pool[idx] + for repeat_times in range(len(arch_list[idx + 1]))] + for idx in range(len(choice_block_pool))] + [[stem[1]]] + + # generate childnet + model = gen_childnet( + arch_list, + arch_def, + num_classes=args.num_classes, + drop_rate=args.dropout_rate, + global_pool=args.gp) + + # initialize distributed parameters + distributed = args.num_gpu > 1 + torch.cuda.set_device(args.local_rank) + if args.local_rank == 0: + logger.info( + 'Training on Process {} with {} GPUs.'.format( + args.local_rank, args.num_gpu)) + + # fix random seeds + torch.manual_seed(args.trial_id) + torch.cuda.manual_seed_all(args.trial_id) + np.random.seed(args.trial_id) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # get parameters and FLOPs of model + if args.local_rank == 0: + macs, params = get_model_flops_params(model, input_size=( + 1, 3, args.image_size, args.image_size)) + logger.info( + '[Model-{}] Flops: {} Params: {}'.format(args.selection, macs, params)) + + # create optimizer + model = model.cuda() + optimizer = create_optimizer(args, model) + + # optionally resume from a checkpoint + resume_epoch = None + if args.auto_resume: + if int(timm.__version__[2]) >= 3: + resume_epoch = resume_checkpoint(model, args.experiment_dir, optimizer) + else: + resume_state, resume_epoch = resume_checkpoint(model, args.experiment_dir) + optimizer.load_state_dict(resume_state['optimizer']) + del resume_state + + model_ema = None + if args.use_ema: + model_ema = ModelEma( + model, + decay=args.ema_decay, + device='cpu' if args.ema_cpu else '', + resume=args.experiment_dir if args.auto_resume else None) + + # initialize training parameters + eval_metric = args.eval_metrics + best_metric, best_epoch, saver = None, None, None + if 
args.local_rank == 0: + decreasing = True if eval_metric == 'loss' else False + if int(timm.__version__[2]) >= 3: + saver = CheckpointSaver(model, optimizer, + checkpoint_dir=args.checkpoint_dir, + recovery_dir=args.checkpoint_dir, + model_ema=model_ema, + decreasing=decreasing, + max_history=2) + else: + saver = CheckpointSaver( + checkpoint_dir=args.checkpoint_dir, + recovery_dir=args.checkpoint_dir, + decreasing=decreasing, + max_history=2) + + if distributed: + torch.distributed.init_process_group(backend='nccl', init_method='env://') + + if args.sync_bn: + try: + if HAS_APEX: + model = convert_syncbn_model(model) + else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if args.local_rank == 0: + logger.info('Converted model to use Synchronized BatchNorm.') + except Exception as e: + if args.local_rank == 0: + logger.error( + 'Failed to enable Synchronized BatchNorm. ' + 'Install Apex or Torch >= 1.1 with exception {}'.format(e)) + if HAS_APEX: + model = DDP(model, delay_allreduce=True) + else: + if args.local_rank == 0: + logger.info( + "Using torch DistributedDataParallel. 
Install NVIDIA Apex for Apex DDP.") + # can use device str in Torch >= 1.1 + model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True) + + # imagenet train dataset + train_dir = os.path.join(args.data_dir, 'train') + if not os.path.exists(train_dir) and args.local_rank == 0: + logger.error('Training folder does not exist at: {}'.format(train_dir)) + exit(1) + dataset_train = Dataset(train_dir) + loader_train = create_loader( + dataset_train, + input_size=(3, args.image_size, args.image_size), + batch_size=args.batch_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + num_aug_splits=0, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=args.workers, + distributed=distributed, + collate_fn=None, + pin_memory=args.pin_mem, + interpolation='random', + re_mode=args.re_mode, + re_prob=args.re_prob + ) + + # imagenet validation dataset + eval_dir = os.path.join(args.data_dir, 'val') + if not os.path.exists(eval_dir) and args.local_rank == 0: + logger.error( + 'Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) + dataset_eval = Dataset(eval_dir) + loader_eval = create_loader( + dataset_eval, + input_size=(3, args.image_size, args.image_size), + batch_size=args.val_batch_mul * args.batch_size, + is_training=False, + interpolation=args.interpolation, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=args.workers, + distributed=distributed, + pin_memory=args.pin_mem + ) + + # whether to use label smoothing + if args.smoothing > 0.: + train_loss_fn = LabelSmoothingCrossEntropy( + smoothing=args.smoothing).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + else: + train_loss_fn = nn.CrossEntropyLoss().cuda() + validate_loss_fn = train_loss_fn + + # create learning rate scheduler + lr_scheduler, num_epochs = create_scheduler(args, optimizer) + start_epoch = resume_epoch if resume_epoch is not None 
else 0 + if start_epoch > 0: + lr_scheduler.step(start_epoch) + if args.local_rank == 0: + logger.info('Scheduled epochs: {}'.format(num_epochs)) + + try: + best_record, best_ep = 0, 0 + for epoch in range(start_epoch, num_epochs): + if distributed: + loader_train.sampler.set_epoch(epoch) + + train_metrics = train_epoch( + epoch, + model, + loader_train, + optimizer, + train_loss_fn, + args, + lr_scheduler=lr_scheduler, + saver=saver, + output_dir=args.checkpoint_dir, + model_ema=model_ema, + logger=logger, + writer=writer, + local_rank=args.local_rank) + + eval_metrics = validate( + epoch, + model, + loader_eval, + validate_loss_fn, + args, + logger=logger, + writer=writer, + local_rank=args.local_rank, + result_path=args.result_path + ) + + if model_ema is not None and not args.ema_cpu: + ema_eval_metrics = validate( + epoch, + model_ema.ema, + loader_eval, + validate_loss_fn, + args, + log_suffix='_EMA', + logger=logger, + writer=writer, + local_rank=args.local_rank + ) + eval_metrics = ema_eval_metrics + + if lr_scheduler is not None: + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + + update_summary(epoch, train_metrics, eval_metrics, os.path.join( + args.checkpoint_dir, 'summary.csv'), write_header=best_metric is None) + + if saver is not None: + # save proper checkpoint with eval metric + save_metric = eval_metrics[eval_metric] + + if int(timm.__version__[2]) >= 3: + best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) + else: + best_metric, best_epoch = saver.save_checkpoint( + model, optimizer, args, + epoch=epoch, metric=save_metric) + + if best_record < eval_metrics[eval_metric]: + best_record = eval_metrics[eval_metric] + best_ep = epoch + + if args.local_rank == 0: + logger.info( + '*** Best metric: {0} (epoch {1})'.format(best_record, best_ep)) + + except KeyboardInterrupt: + pass + + if best_metric is not None: + logger.info( + '*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + 
save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch) + + +if __name__ == '__main__': + main() diff --git a/dubhe-tadl/cream/selector.py b/dubhe-tadl/cream/selector.py new file mode 100644 index 0000000..3e56222 --- /dev/null +++ b/dubhe-tadl/cream/selector.py @@ -0,0 +1,21 @@ +import sys + +sys.path.append('../..') +from pytorch.selector import Selector + + +class ClassicnasSelector(Selector): + def __init__(self, *args, single_candidate=True): + super().__init__(single_candidate) + self.args = args + + def fit(self): + """ + only one candatite, function passed + """ + pass + + +if __name__ == "__main__": + hpo_selector = ClassicnasSelector() + hpo_selector.fit() diff --git a/dubhe-tadl/cream/test.py b/dubhe-tadl/cream/test.py new file mode 100644 index 0000000..9e24a87 --- /dev/null +++ b/dubhe-tadl/cream/test.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import os +import warnings +import datetime +import torch +import torch.nn as nn + +# from torch.utils.tensorboard import SummaryWriter + +# import timm packages +from timm.utils import ModelEma +from timm.models import resume_checkpoint +from timm.data import Dataset, create_loader + +# import apex as distributed package +try: + from apex.parallel import convert_syncbn_model + from apex.parallel import DistributedDataParallel as DDP + + HAS_APEX = True +except ImportError as e: + print(e) + from torch.nn.parallel import DistributedDataParallel as DDP + + HAS_APEX = False + +# import models and training functions +from lib.core.test import validate +from lib.models.structures.childnet import gen_childnet +from lib.utils.util import parse_config_args, get_logger, get_model_flops_params +from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def main(): + args, cfg = parse_config_args('child net testing') + + 
# resolve logging + output_dir = os.path.join(cfg.SAVE_PATH, + "{}-{}".format(datetime.date.today().strftime('%m%d'), + cfg.MODEL)) + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + if args.local_rank == 0: + logger = get_logger(os.path.join(output_dir, 'test.log')) + writer = None # SummaryWriter(os.path.join(output_dir, 'runs')) + else: + writer, logger = None, None + + # retrain model selection + if cfg.NET.SELECTION == 481: + arch_list = [ + [0], [ + 3, 4, 3, 1], [ + 3, 2, 3, 0], [ + 3, 3, 3, 1], [ + 3, 3, 3, 3], [ + 3, 3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + elif cfg.NET.SELECTION == 43: + arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 96 + elif cfg.NET.SELECTION == 14: + arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] + cfg.DATASET.IMAGE_SIZE = 64 + elif cfg.NET.SELECTION == 112: + arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 160 + elif cfg.NET.SELECTION == 287: + arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + elif cfg.NET.SELECTION == 604: + arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3], + [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + else: + raise ValueError("Model Test Selection is not Supported!") + + # define childnet architecture from arch_list + stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25'] + # TODO: this param from NNI is different from microsoft/Cream. 
+ choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25', + 'ir_r1_k5_s2_e4_c40_se0.25', + 'ir_r1_k3_s2_e6_c80_se0.25', + 'ir_r1_k3_s1_e6_c96_se0.25', + 'ir_r1_k5_s2_e6_c192_se0.25'] + arch_def = [[stem[0]]] + [[choice_block_pool[idx] + for repeat_times in range(len(arch_list[idx + 1]))] + for idx in range(len(choice_block_pool))] + [[stem[1]]] + + # generate childnet + model = gen_childnet( + arch_list, + arch_def, + num_classes=cfg.DATASET.NUM_CLASSES, + drop_rate=cfg.NET.DROPOUT_RATE, + global_pool=cfg.NET.GP) + + if args.local_rank == 0: + macs, params = get_model_flops_params(model, input_size=( + 1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE)) + logger.info( + '[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params)) + + # initialize distributed parameters + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.local_rank == 0: + logger.info( + "Training on Process {} with {} GPUs.".format( + args.local_rank, cfg.NUM_GPU)) + + # resume model from checkpoint + assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH) + resume_checkpoint(model, cfg.RESUME_PATH) + + model = model.cuda() + + model_ema = None + if cfg.NET.EMA.USE: + # Important to create EMA model after cuda(), DP wrapper, and AMP but + # before SyncBN and DDP wrapper + model_ema = ModelEma( + model, + decay=cfg.NET.EMA.DECAY, + device='cpu' if cfg.NET.EMA.FORCE_CPU else '', + resume=cfg.RESUME_PATH) + + # imagenet validation dataset + eval_dir = os.path.join(cfg.DATA_DIR, 'val') + if not os.path.exists(eval_dir) and args.local_rank == 0: + logger.error( + 'Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) + + dataset_eval = Dataset(eval_dir) + loader_eval = create_loader( + dataset_eval, + input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), + batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE, + is_training=False, + num_workers=cfg.WORKERS, + 
distributed=True, + pin_memory=cfg.DATASET.PIN_MEM, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD + ) + + # only test accuracy of model-EMA + validate_loss_fn = nn.CrossEntropyLoss().cuda() + validate(0, model, loader_eval, validate_loss_fn, cfg, + log_suffix='_EMA', logger=logger, + writer=writer, local_rank=args.local_rank) + + if cfg.NET.EMA.USE: + validate(0, model_ema.ema, loader_eval, validate_loss_fn, cfg, + log_suffix='_EMA', logger=logger, + writer=writer, local_rank=args.local_rank) + + +if __name__ == '__main__': + main() diff --git a/dubhe-tadl/cream/trainer.py b/dubhe-tadl/cream/trainer.py new file mode 100644 index 0000000..e815c6a --- /dev/null +++ b/dubhe-tadl/cream/trainer.py @@ -0,0 +1,312 @@ +# https://github.com/microsoft/nni/blob/v2.0/examples/nas/cream/train.py +import sys + +sys.path.append('../..') + +import os +import sys +import time +import json +import torch +import numpy as np +import torch.nn as nn + +from argparse import ArgumentParser + +# import timm packages +from timm.loss import LabelSmoothingCrossEntropy +from timm.data import Dataset, create_loader +from timm.models import resume_checkpoint + +# import apex as distributed package +# try: +# from apex.parallel import DistributedDataParallel as DDP +# from apex.parallel import convert_syncbn_model +# +# USE_APEX = True +# except ImportError as e: +# print(e) +# from torch.nn.parallel import DistributedDataParallel as DDP +# +# USE_APEX = False + +# import models and training functions +from lib.utils.flops_table import FlopsEst +from lib.models.structures.supernet import gen_supernet +from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from lib.utils.util import get_logger, \ + create_optimizer_supernet, create_supernet_scheduler + +from pytorch.utils import mkdirs, str2bool +from pytorch.callbacks import LRSchedulerCallback +from pytorch.callbacks import ModelCheckpoint +from algorithms import 
CreamSupernetTrainer +from algorithms import RandomMutator + + +def parse_args(): + """See lib.utils.config""" + parser = ArgumentParser() + + # path + parser.add_argument("--checkpoint_dir", type=str, default='') + parser.add_argument("--data_dir", type=str, default='./data') + parser.add_argument("--experiment_dir", type=str, default='./') + parser.add_argument("--model_name", type=str, default='trainer') + parser.add_argument("--log_path", type=str, default='output/log') + parser.add_argument("--result_path", type=str, default='output/result.json') + parser.add_argument("--search_space_path", type=str, default='output/search_space.json') + parser.add_argument("--best_selected_space_path", type=str, + default='output/selected_space.json') + + # int + parser.add_argument("--acc_gap", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--epochs", type=int, default=200) + parser.add_argument("--flops_minimum", type=int, default=0) + parser.add_argument("--flops_maximum", type=int, default=200) + parser.add_argument("--image_size", type=int, default=224) + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument("--log_interval", type=int, default=50) + parser.add_argument("--meta_sta_epoch", type=int, default=20) + parser.add_argument("--num_classes", type=int, default=1000) + parser.add_argument("--num_gpu", type=int, default=1) + parser.add_argument("--pool_size", type=int, default=10) + parser.add_argument("--trial_id", type=int, default=42) + parser.add_argument("--slice_num", type=int, default=4) + parser.add_argument("--tta", type=int, default=0) + parser.add_argument("--update_iter", type=int, default=1300) + parser.add_argument("--workers", type=int, default=4) + + # float + parser.add_argument("--color_jitter", type=float, default=0.4) + parser.add_argument("--dropout_rate", type=float, default=0.0) + parser.add_argument("--lr", type=float, default=1e-2) + 
parser.add_argument("--meta_lr", type=float, default=1e-4) + parser.add_argument("--opt_eps", type=float, default=1e-2) + parser.add_argument("--re_prob", type=float, default=0.2) + parser.add_argument("--momentum", type=float, default=0.9) + parser.add_argument("--smoothing", type=float, default=0.1) + parser.add_argument("--weight_decay", type=float, default=1e-4) + + # bool + parser.add_argument("--auto_resume", type=str2bool, default='False') + parser.add_argument("--dil_conv", type=str2bool, default='False') + parser.add_argument("--resunit", type=str2bool, default='False') + parser.add_argument("--sync_bn", type=str2bool, default='False') + parser.add_argument("--verbose", type=str2bool, default='False') + + # str + # gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"] + parser.add_argument("--gp", type=str, default='avg') + parser.add_argument("--interpolation", type=str, default='bilinear') + parser.add_argument("--opt", type=str, default='sgd') + parser.add_argument("--pick_method", type=str, default='meta') + parser.add_argument("--re_mode", type=str, default='pixel') + + args = parser.parse_args() + args.sync_bn = False + args.verbose = False + args.data_dir = args.data_dir + "/imagenet" + return args + + +def main(): + args = parse_args() + + mkdirs(args.experiment_dir, + args.best_selected_space_path, + args.search_space_path, + args.result_path, + args.log_path) + + with open(args.result_path, "w") as ss_file: + ss_file.write('') + + # resolve logging + + if len(args.checkpoint_dir > 1): + mkdirs(args.checkpoint_dir + "/") + args.checkpoint_dir = os.path.join( + args.checkpoint_dir, + "{}_{}".format(args.model_name, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + ) + if not os.path.exists(args.checkpoint_dir): + os.mkdir(args.checkpoint_dir) + + if args.local_rank == 0: + logger = get_logger(args.log_path) + else: + logger = None + + # initialize distributed parameters + torch.cuda.set_device(args.local_rank) + # 
torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.local_rank == 0: + logger.info( + 'Training on Process %d with %d GPUs.', + args.local_rank, args.num_gpu) + + # fix random seeds + torch.manual_seed(args.trial_id) + torch.cuda.manual_seed_all(args.trial_id) + np.random.seed(args.trial_id) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # generate supernet and optimizer + model, sta_num, resolution, search_space = gen_supernet( + flops_minimum=args.flops_minimum, + flops_maximum=args.flops_maximum, + num_classes=args.num_classes, + drop_rate=args.dropout_rate, + global_pool=args.gp, + resunit=args.resunit, + dil_conv=args.dil_conv, + slice=args.slice_num, + verbose=args.verbose, + logger=logger) + optimizer = create_optimizer_supernet(args, model) + + # number of choice blocks in supernet + choice_num = len(model.blocks[7]) + if args.local_rank == 0: + logger.info('Supernet created, param count: %d', ( + sum([m.numel() for m in model.parameters()]))) + logger.info('resolution: %d', resolution) + logger.info('choice number: %d', choice_num) + with open(args.search_space_path, "w") as f: + print("dump search space.") + json.dump({'search_space': search_space}, f) + + # initialize flops look-up table + model_est = FlopsEst(model) + flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed + model = model.cuda() + + # convert model to distributed mode + if args.sync_bn: + try: + # if USE_APEX: + # model = convert_syncbn_model(model) + # else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if args.local_rank == 0: + logger.info('Converted model to use Synchronized BatchNorm.') + except Exception as exception: + logger.info( + 'Failed to enable Synchronized BatchNorm. 
' + 'Install Apex or Torch >= 1.1 with Exception %s', exception) + # if USE_APEX: + # model = DDP(model, delay_allreduce=True) + # else: + # if args.local_rank == 0: + # logger.info( + # "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") + # # can use device str in Torch >= 1.1 + # model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True) + + # optionally resume from a checkpoint + resume_epoch = None + if False: # args.auto_resume: + checkpoint = torch.load(args.experiment_dir) + + model.load_state_dict(checkpoint['child_model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + resume_epoch = checkpoint['epoch'] + + # create learning rate scheduler + lr_scheduler, num_epochs = create_supernet_scheduler(optimizer, args.epochs, args.num_gpu, + args.batch_size, args.lr) + + start_epoch = resume_epoch if resume_epoch is not None else 0 + if start_epoch > 0: + lr_scheduler.step(start_epoch) + + if args.local_rank == 0: + logger.info('Scheduled epochs: %d', num_epochs) + + # imagenet train dataset + train_dir = os.path.join(args.data_dir, 'train') + if not os.path.exists(train_dir): + logger.info('Training folder does not exist at: %s', train_dir) + sys.exit() + + dataset_train = Dataset(train_dir) + loader_train = create_loader( + dataset_train, + input_size=(3, args.image_size, args.image_size), + batch_size=args.batch_size, + is_training=True, + use_prefetcher=True, + re_prob=args.re_prob, + re_mode=args.re_mode, + color_jitter=args.color_jitter, + interpolation='random', + num_workers=args.workers, + distributed=False, + collate_fn=None, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD + ) + + # imagenet validation dataset + eval_dir = os.path.join(args.data_dir, 'val') + if not os.path.isdir(eval_dir): + logger.info('Validation folder does not exist at: %s', eval_dir) + sys.exit() + dataset_eval = Dataset(eval_dir) + loader_eval = create_loader( + 
dataset_eval, + input_size=(3, args.image_size, args.image_size), + batch_size=4 * args.batch_size, + is_training=False, + use_prefetcher=True, + num_workers=args.workers, + distributed=False, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + interpolation=args.interpolation + ) + + # whether to use label smoothing + if args.smoothing > 0.: + train_loss_fn = LabelSmoothingCrossEntropy( + smoothing=args.smoothing).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + else: + train_loss_fn = nn.CrossEntropyLoss().cuda() + validate_loss_fn = train_loss_fn + + mutator = RandomMutator(model) + + _callbacks = [LRSchedulerCallback(lr_scheduler)] + if len(args.checkpoint_dir) > 1: + _callbacks.append(ModelCheckpoint(checkpoint_dir)) + trainer = CreamSupernetTrainer(args.best_selected_space_path, model, train_loss_fn, + validate_loss_fn, + optimizer, num_epochs, loader_train, loader_eval, + result_path=args.result_path, + mutator=mutator, + batch_size=args.batch_size, + log_frequency=args.log_interval, + meta_sta_epoch=args.meta_sta_epoch, + update_iter=args.update_iter, + slices=args.slice_num, + pool_size=args.pool_size, + pick_method=args.pick_method, + choice_num=choice_num, + sta_num=sta_num, + acc_gap=args.acc_gap, + flops_dict=flops_dict, + flops_fixed=flops_fixed, + local_rank=args.local_rank, + callbacks=_callbacks) + + trainer.train() + + +if __name__ == '__main__': + main() diff --git a/dubhe-tadl/darts/__init__.py b/dubhe-tadl/darts/__init__.py new file mode 100644 index 0000000..34b4bca --- /dev/null +++ b/dubhe-tadl/darts/__init__.py @@ -0,0 +1,2 @@ +from pytorch.darts.dartstrainer import DartsTrainer +from pytorch.darts.dartsmutator import DartsMutator \ No newline at end of file diff --git a/dubhe-tadl/darts/darts_retrain.py b/dubhe-tadl/darts/darts_retrain.py new file mode 100644 index 0000000..7996a9a --- /dev/null +++ b/dubhe-tadl/darts/darts_retrain.py @@ -0,0 +1,205 @@ +import sys +sys.path.append('..'+ '/' + 
class DartsRetrainer(Retrainer):
    """Retrain a fixed DARTS architecture and evaluate it epoch by epoch.

    Parameters
    ----------
    aux_weight : float
        Weight applied to the auxiliary-head loss; the auxiliary loss is only
        added when this value is greater than zero.
    grad_clip : float
        Maximum gradient norm passed to ``nn.utils.clip_grad_norm_``.
    epochs : int
        Total number of epochs, used only for log formatting.
    log_frequency : int
        A progress line is logged every ``log_frequency`` steps and on the
        last step of each loader.
    """

    def __init__(self, aux_weight, grad_clip, epochs, log_frequency):
        self.epochs = epochs
        self.log_frequency = log_frequency
        self.grad_clip = grad_clip
        self.aux_weight = aux_weight

    @staticmethod
    def _make_meters():
        """Return fresh ``(top1, top5, losses)`` running-average meters."""
        return AverageMeter("top1"), AverageMeter("top5"), AverageMeter("losses")

    @staticmethod
    def _record(logits, targets, loss, batch_size, top1, top5, losses):
        """Fold one batch's loss and top-1/top-5 accuracy into the meters."""
        accuracy = utils.accuracy(logits, targets, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(accuracy["acc1"], batch_size)
        top5.update(accuracy["acc5"], batch_size)

    def train(self, train_loader, model, optimizer, criterion, epoch):
        """Run one training epoch over ``train_loader``.

        The model is called in training mode and must return a
        ``(logits, aux_logits)`` pair. Nothing is returned; progress and the
        final top-1 average are written to the module logger.
        """
        top1, top5, losses = self._make_meters()
        last_step = len(train_loader) - 1
        logger.info("Epoch %d LR %.6f", epoch, optimizer.param_groups[0]["lr"])

        model.train()

        for step, batch in enumerate(train_loader):
            inputs, targets = batch
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            optimizer.zero_grad()
            logits, aux_logits = model(inputs)
            loss = criterion(logits, targets)
            if self.aux_weight > 0.:
                loss += self.aux_weight * criterion(aux_logits, targets)
            loss.backward()
            # Clip before stepping so one bad batch cannot blow up the weights.
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            self._record(logits, targets, loss, batch_size, top1, top5, losses)

            if step % self.log_frequency == 0 or step == last_step:
                logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, self.epochs, step, last_step, losses=losses,
                        top1=top1, top5=top5))

        logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

    def validate(self, valid_loader, model, criterion, epoch, cur_step):
        """Evaluate ``model`` on ``valid_loader`` and return the top-1 average.

        ``cur_step`` is accepted for interface compatibility but is not used
        by this implementation. Gradients are disabled during evaluation.
        """
        top1, top5, losses = self._make_meters()
        last_step = len(valid_loader) - 1

        model.eval()

        with torch.no_grad():
            for step, batch in enumerate(valid_loader):
                inputs, targets = batch
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                batch_size = inputs.size(0)

                logits = model(inputs)
                loss = criterion(logits, targets)

                self._record(logits, targets, loss, batch_size, top1, top5, losses)

                if step % self.log_frequency == 0 or step == last_step:
                    logger.info(
                        "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, self.epochs, step, last_step, losses=losses,
                            top1=top1, top5=top5))

        logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

        return top1.avg
parser.add_argument("--best_selected_space_path", type=str, + default='./best_selected_space.json', help="final best selected space") + parser.add_argument("--best_checkpoint_dir", type=str, + default='./', help="default name is best_checkpoint_epoch{}.pth") + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + parser.add_argument("--layers", default=20, type=int) + parser.add_argument("--lr", default=0.025, type=float) + parser.add_argument("--batch_size", default=128, type=int) + parser.add_argument("--log_frequency", default=10, type=int) + parser.add_argument("--epochs", default=5, type=int) + parser.add_argument("--aux_weight", default=0.4, type=float) + parser.add_argument("--drop_path_prob", default=0.2, type=float) + parser.add_argument("--workers", default=4, type=int) + parser.add_argument("--channels", default=36, type=int) + parser.add_argument("--grad_clip", default=5., type=float) + parser.add_argument("--class_num", default=10, type=int, help="cifar10") + args = parser.parse_args() + + mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir) + init_logger(args.log_path) + logger.info(args) + set_seed(args.trial_id) + dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir) + + model = CNN(32, 3, args.channels, args.class_num, args.layers, auxiliary=True) + apply_fixed_architecture(model, args.best_selected_space_path) + criterion = nn.CrossEntropyLoss() + + model.to(device) + criterion.to(device) + + optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6) + + train_loader = torch.utils.data.DataLoader(dataset_train, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True) + valid_loader = torch.utils.data.DataLoader(dataset_valid, + batch_size=args.batch_size, + shuffle=False, + 
num_workers=args.workers, + pin_memory=True) + retrainer = DartsRetrainer(aux_weight=args.aux_weight, + grad_clip=args.grad_clip, + epochs=args.epochs, + log_frequency = args.log_frequency) + + # result = {"Accuracy": [], "Cost_time": ''} + best_top1 = 0. + start_time = time.time() + with open(args.result_path, "w") as file: + file.write('') + for epoch in range(args.epochs): + drop_prob = args.drop_path_prob * epoch / args.epochs + model.drop_path_prob(drop_prob) + + # training + retrainer.train(train_loader, model, optimizer, criterion, epoch) + + # validation + cur_step = (epoch + 1) * len(train_loader) + top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step) + # 后端在终端过滤,{"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value":96.7}} + logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) + with open(args.result_path, "a") as file: + file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) + '\n') + # result["Accuracy"].append(top1) + best_top1 = max(best_top1, top1) + + lr_scheduler.step() + + logger.info("Final best Prec@1 = {:.4%}".format(best_top1)) + cost_time = time.time() - start_time + # 后端在终端过滤,{"type": "Cost_time", "result": {"value": "* s"}} + logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) + with open(args.result_path, "a") as file: + file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) + + # result["Cost_time"] = str(cost_time) + ' s' + # dump_global_result(args.result_path, result) + save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch) + logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch)))) + diff --git a/dubhe-tadl/darts/darts_select.py b/dubhe-tadl/darts/darts_select.py new file mode 100644 index 0000000..24685e5 --- /dev/null +++ 
class DartsSelector(Selector):
    """Selector stage for DARTS whose ``fit`` step performs no work."""

    def __init__(self, single_candidate=True):
        # Forward the flag unchanged to the base Selector.
        super(DartsSelector, self).__init__(single_candidate)

    def fit(self):
        """Do nothing and return ``None``."""
        return None
help='trial_id,start from 0') + parser.add_argument("--layers", default=8, type=int) + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--log_frequency", default=10, type=int) + parser.add_argument("--epochs", default=5, type=int) + parser.add_argument("--channels", default=16, type=int) + parser.add_argument('--model_lr', type=float, default=0.025, help='learning rate for training model weights') + parser.add_argument('--arch_lr', type=float, default=3e-4, help='learning rate for training architecture') + parser.add_argument("--unrolled", default=False, action="store_true") + parser.add_argument("--visualization", default=False, action="store_true") + parser.add_argument("--class_num", default=10, type=int, help="cifar10") + args = parser.parse_args() + + mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path) + init_logger(args.log_path, "info") + logger.info(args) + set_seed(args.trial_id) + + dataset_train, dataset_valid = datasets.get_dataset("cifar10", root=args.data_dir) + model = CNN(32, 3, args.channels, args.class_num, args.layers) + criterion = nn.CrossEntropyLoss() + + optim = torch.optim.SGD(model.parameters(), args.model_lr, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) + + trainer = DartsTrainer(model, + loss=criterion, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + optimizer=optim, + num_epochs=args.epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + search_space_path = args.search_space_path, + batch_size=args.batch_size, + log_frequency=args.log_frequency, + result_path=args.result_path, + unrolled=args.unrolled, + arch_lr=args.arch_lr, + callbacks=[LRSchedulerCallback(lr_scheduler), BestArchitectureCheckpoint(args.best_selected_space_path, args.epochs)]) + + if args.visualization: + trainer.enable_visualization() + t1 = time.time() + trainer.train() 
class DartsMutator(Mutator):
    """
    Mutator that wires the model in the differentiable (DARTS) manner.

    Every LayerChoice gets one learnable logit per candidate op plus one extra
    logit for an implicit "no op" connection; when that extra entry wins, no
    candidate of the LayerChoice is chosen on export.

    During search every InputChoice keeps all of its candidates. On export an
    InputChoice with ``n_chosen`` set keeps the ``n_chosen`` sources with the
    largest top softmax weight; sources that are not LayerChoice keys compete
    with weight 0.

    Branches can be cut by setting an entry of ``choices`` to ``-inf`` (softmax
    gives 0 there). The gradient at such a position is 0 and arithmetic with
    ``-inf`` can yield ``nan``, so the update phase has to be handled with care.

    Attributes
    ----------
    choices: ParameterDict
        maps each LayerChoice key to its float tensor of connection logits.
    """

    def __init__(self, model):
        super().__init__(model)
        self.choices = nn.ParameterDict()
        for mutable in self.mutables:
            if not isinstance(mutable, LayerChoice):
                continue
            # One logit per candidate op plus one for the implicit zero op.
            logits = 1.0E-3 * torch.randn(mutable.length + 1)
            self.choices[mutable.key] = nn.Parameter(logits)

    def device(self):
        """Return the device of the first architecture parameter, or ``None`` if there is none."""
        first = next(iter(self.choices.values()), None)
        if first is None:
            return None
        return first.device

    def sample_search(self):
        """Return soft op weights for LayerChoices and all-true masks for InputChoices."""
        sampled = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # Drop the trailing entry: it belongs to the implicit zero op.
                probabilities = F.softmax(self.choices[mutable.key], dim=-1)
                sampled[mutable.key] = probabilities[:-1]
            elif isinstance(mutable, InputChoice):
                sampled[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
        return sampled

    def sample_final(self):
        """Return the discrete architecture: one-hot ops and boolean input masks."""
        decisions = {}
        strongest = {}

        # Pass 1: pick the best real op of every LayerChoice.
        for mutable in self.mutables:
            if not isinstance(mutable, LayerChoice):
                continue
            probabilities = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
            top_value, top_index = torch.max(probabilities, 0)
            strongest[mutable.key] = top_value
            decisions[mutable.key] = F.one_hot(top_index, num_classes=len(mutable)).view(-1).bool()

        # Pass 2: let the incoming edges of every InputChoice compete.
        for mutable in self.mutables:
            if not isinstance(mutable, InputChoice):
                continue
            if mutable.n_chosen is None:
                decisions[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
                continue

            edge_weights = []
            for src_key in mutable.choose_from:
                if src_key not in strongest:
                    _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
                edge_weights.append(strongest.get(src_key, 0.))
            edge_weights = torch.tensor(edge_weights)  # pylint: disable=not-callable
            _, kept_indices = torch.topk(edge_weights, mutable.n_chosen)

            keep_mask = []
            for position, src_key in enumerate(mutable.choose_from):
                kept = position in kept_indices
                if not kept and src_key in decisions:
                    # An edge that is never selected needs no op at all, which
                    # avoids redundant computation on it.
                    decisions[src_key] = torch.zeros_like(decisions[src_key])
                keep_mask.append(kept)
            decisions[mutable.key] = torch.tensor(keep_mask, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
        return decisions

    def _generate_search_space(self):
        """
        Build a flattened description of the search space from the mutables.

        The result has two entries: ``"op_list"``, filled once from the first
        LayerChoice as ``{"_type": "layer_choice", "_value": names + ["none"]}``,
        and ``"search_space"``, which maps each LayerChoice key to the string
        ``"op_list"`` and each InputChoice key to
        ``{"_type": "input_choice", "_value": {"candidates": [...], "n_chosen": k}}``.
        Collection stops once ``"search_space"`` holds 36 entries.

        Returns
        -------
        dict
            the generated search space

        Raises
        ------
        TypeError
            if a mutable is neither a LayerChoice nor an InputChoice.
        """
        res = OrderedDict()
        res["op_list"] = OrderedDict()
        res["search_space"] = OrderedDict()
        space = res["search_space"]

        seen = set()
        for mutable in self.mutables:
            # for now only a flattened search space is generated
            if len(space) >= 36:
                break

            if isinstance(mutable, LayerChoice):
                key = mutable.key
                if key in seen:
                    continue
                names = mutable.names
                if not res["op_list"]:
                    res["op_list"] = {"_type": "layer_choice", "_value": names + ["none"]}
                space[key] = "op_list"
                seen.add(key)
            elif isinstance(mutable, InputChoice):
                key = mutable.key
                if key in seen:
                    continue
                space[key] = {"_type": "input_choice",
                              "_value": {"candidates": mutable.choose_from,
                                         "n_chosen": mutable.n_chosen}}
                seen.add(key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))

        return res
+ """ + def __init__(self, model, loss, metrics, + optimizer, num_epochs, dataset_train, dataset_valid, search_space_path, result_path, num_pre_epochs=0, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + callbacks=None, arch_lr=3.0E-4, unrolled=False): + super().__init__(model, mutator if mutator is not None else DartsMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) + + self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arch_lr, betas=(0.5, 0.999), weight_decay=1.0E-3) + self.unrolled = unrolled + self.num_pre_epoches = num_pre_epochs + self.result_path = result_path + with open(self.result_path, "w") as file: + file.write('') + n_train = len(self.dataset_train) + split = n_train // 2 + indices = list(range(n_train)) + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split]) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) + self.train_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=train_sampler, + num_workers=workers) + self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=batch_size, + sampler=valid_sampler, + num_workers=workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=batch_size, + num_workers=workers) + if search_space_path is not None: + dump_global_result(search_space_path, self.mutator._generate_search_space()) + + # self.result = {"Accuracy": []} + + def train_one_epoch(self, epoch): + self.model.train() + self.mutator.train() + meters = AverageMeterGroup() + # t1 = time() + for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): + trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) + val_X, val_y = val_X.to(self.device), val_y.to(self.device) + + if epoch >= self.num_pre_epoches: + # phase 1. 
architecture step + self.ctrl_optim.zero_grad() + if self.unrolled: + self._unrolled_backward(trn_X, trn_y, val_X, val_y) + else: + self._backward(val_X, val_y) + self.ctrl_optim.step() + + # phase 2: child network step + self.optimizer.zero_grad() + logits, loss = self._logits_and_loss(trn_X, trn_y) + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping + self.optimizer.step() + + metrics = self.metrics(logits, trn_y) + metrics["loss"] = loss.item() + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.train_loader), meters) + + def validate_one_epoch(self, epoch, log_print=True): + self.model.eval() + self.mutator.eval() + meters = AverageMeterGroup() + with torch.no_grad(): + self.mutator.reset() + for step, (X, y) in enumerate(self.test_loader): + X, y = X.to(self.device), y.to(self.device) + logits = self.model(X) + metrics = self.metrics(logits, y) + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.test_loader), meters) + if log_print: + # 后端在终端过滤,{"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value":96.7}} + logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) + with open(self.result_path, "a") as file: + file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) + '\n') + # self.result["Accuracy"].append(meters.get_last_acc()) + + def _logits_and_loss(self, X, y): + self.mutator.reset() + logits = self.model(X) + loss = self.loss(logits, y) + # self._write_graph_status() + return logits, loss + + def _backward(self, val_X, val_y): + """ + Simple backward with gradient descent + """ + _, loss = 
self._logits_and_loss(val_X, val_y) + loss.backward() + + def _unrolled_backward(self, trn_X, trn_y, val_X, val_y): + """ + Compute unrolled loss and backward its gradients + """ + backup_params = copy.deepcopy(tuple(self.model.parameters())) + + # do virtual step on training data + lr = self.optimizer.param_groups[0]["lr"] + momentum = self.optimizer.param_groups[0]["momentum"] + weight_decay = self.optimizer.param_groups[0]["weight_decay"] + self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay) + + # calculate unrolled loss on validation data + # keep gradients for model here for compute hessian + _, loss = self._logits_and_loss(val_X, val_y) + w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters()) + w_grads = torch.autograd.grad(loss, w_model + w_ctrl) + d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):] + + # compute hessian and final gradients + hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y) + with torch.no_grad(): + for param, d, h in zip(w_ctrl, d_ctrl, hessian): + # gradient = dalpha - lr * hessian + param.grad = d - lr * h + + # restore weights + self._restore_weights(backup_params) + + def _compute_virtual_model(self, X, y, lr, momentum, weight_decay): + """ + Compute unrolled weights w` + """ + # don't need zero_grad, using autograd to calculate gradients + _, loss = self._logits_and_loss(X, y) + gradients = torch.autograd.grad(loss, self.model.parameters()) + with torch.no_grad(): + for w, g in zip(self.model.parameters(), gradients): + m = self.optimizer.state[w].get("momentum_buffer", 0.) 
+ w = w - lr * (momentum * m + g + weight_decay * w) + + def _restore_weights(self, backup_params): + with torch.no_grad(): + for param, backup in zip(self.model.parameters(), backup_params): + param.copy_(backup) + + def _compute_hessian(self, backup_params, dw, trn_X, trn_y): + """ + dw = dw` { L_val(w`, alpha) } + w+ = w + eps * dw + w- = w - eps * dw + hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps) + eps = 0.01 / ||dw|| + """ + self._restore_weights(backup_params) + norm = torch.cat([w.view(-1) for w in dw]).norm() + eps = 0.01 / norm + if norm < 1E-8: + logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item()) + + dalphas = [] + for e in [eps, -2. * eps]: + # w+ = w + eps*dw`, w- = w - eps*dw` + with torch.no_grad(): + for p, d in zip(self.model.parameters(), dw): + p += e * d + + _, loss = self._logits_and_loss(trn_X, trn_y) + dalphas.append(torch.autograd.grad(loss, self.mutator.parameters())) + + dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) } + hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)] + return hessian \ No newline at end of file diff --git a/dubhe-tadl/darts/datasets.py b/dubhe-tadl/darts/datasets.py new file mode 100644 index 0000000..9ddebe7 --- /dev/null +++ b/dubhe-tadl/darts/datasets.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
class Cutout(object):
    """Zero out one randomly placed square patch of an image tensor.

    ``length`` is the nominal side of the patch; the patch is clipped at the
    image borders, so it can be smaller near an edge.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Mask ``img`` in place and return it.

        The spatial size is read from dimensions 1 and 2 of ``img``; the same
        mask is applied across dimension 0.
        """
        height, width = img.size(1), img.size(2)
        keep = np.ones((height, width), np.float32)

        # Draw the patch centre: row first, then column.
        center_row = np.random.randint(height)
        center_col = np.random.randint(width)

        half = self.length // 2
        top, bottom = np.clip([center_row - half, center_row + half], 0, height)
        left, right = np.clip([center_col - half, center_col + half], 0, width)
        keep[top:bottom, left:right] = 0.

        img *= torch.from_numpy(keep).expand_as(img)
        return img


def get_dataset(cls, cutout_length=0, root=None):
    """Return the ``(train, valid)`` dataset pair for the dataset named ``cls``.

    Only ``"cifar10"`` is supported; any other name raises
    ``NotImplementedError``. Training images get a padded random crop and a
    random horizontal flip, are converted to tensors and normalized, and, when
    ``cutout_length`` is positive, get a ``Cutout`` of that length. Validation
    images are only converted and normalized. Both datasets are created with
    ``download=True`` under ``root``.
    """
    if cls != "cifar10":
        raise NotImplementedError

    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
    if cutout_length > 0:
        # Cutout works on tensors, so it goes after ToTensor/Normalize.
        train_steps.append(Cutout(cutout_length))
    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]

    train_transform = transforms.Compose(train_steps)
    valid_transform = transforms.Compose(valid_steps)

    dataset_train = CIFAR10(root=root, train=True, download=True, transform=train_transform)
    dataset_valid = CIFAR10(root=root, train=False, download=True, transform=valid_transform)
    return dataset_train, dataset_valid
class AuxiliaryHead(nn.Module):
    """Auxiliary classifier attached about 2/3 of the way through the network to help gradient flow."""

    def __init__(self, input_size, C, n_classes):
        """Build the head for a 7x7 or 8x8 feature map with ``C`` channels."""
        assert input_size in [7, 8]
        super().__init__()
        # The pooling stride is chosen so both supported sizes pool down to 2x2.
        pool_stride = input_size - 5
        stages = [
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=pool_stride, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*stages)
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)`` for feature map ``x``."""
        features = self.net(x)
        flat = features.flatten(start_dim=1)
        return self.linear(flat)
class Cell(nn.Module):
    """One searchable cell: two preprocessed inputs followed by ``n_nodes`` mutable nodes.

    Parameters
    ----------
    n_nodes : int
        Number of intermediate nodes in the cell.
    channels_pp : int
        Channel count of the output of the cell two steps back (``s0``).
    channels_p : int
        Channel count of the output of the previous cell (``s1``).
    channels : int
        Channel count used inside this cell.
    reduction_p : bool
        Whether the previous cell was a reduction cell.
    reduction : bool
        Whether this cell is a reduction cell.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        # Node ``depth`` receives ``depth`` predecessors (the two preprocessed inputs plus
        # all earlier nodes); in a reduction cell its first two connections are built with
        # stride 2 (``num_downsample_connect=2``).
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        """Run the cell and return the node outputs concatenated along dim 1."""
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)

        # The two preprocessed inputs are excluded; only node outputs are concatenated.
        output = torch.cat(tensors[2:], dim=1)
        return output
class DropPath(nn.Module):
    def __init__(self, p=0.):
        """
        Drop path with probability.

        Parameters
        ----------
        p : float
            Probability of an path to be zeroed.
        """
        super().__init__()
        self.p = p

    def forward(self, x):
        """Randomly zero whole samples of ``x`` while training.

        In eval mode, or when ``p <= 0``, ``x`` is returned unchanged (same
        object). Otherwise each sample along dim 0 is kept with probability
        ``1 - p`` and rescaled by ``1 / (1 - p)``; the others become zero.
        """
        if not self.training or self.p <= 0.:
            return x

        keep_prob = 1. - self.p
        if keep_prob <= 0.:
            # Every path is dropped; dividing by keep_prob would give 0/0 = NaN.
            return torch.zeros_like(x)

        # One Bernoulli draw per data point, broadcast over all remaining
        # dimensions whatever the rank of ``x`` is.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = torch.zeros(mask_shape, device=x.device).bernoulli_(keep_prob)
        # Cast so that multiplying does not promote a lower-precision input.
        return x / keep_prob * mask.to(x.dtype)
class FacConv(nn.Module):
    """
    Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN

    Parameters
    ----------
    C_in : int
        Number of input channels.
    C_out : int
        Number of output channels.
    kernel_length : int
        Length K of the two one-dimensional kernels.
    stride : int
        Stride of the equivalent KxK convolution.
    padding : int
        Padding of the equivalent KxK convolution.
    affine : bool
        Whether the final batch norm has learnable affine parameters.
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
        super().__init__()
        # Each 1-D convolution strides and pads only along its own kernel axis.
        # Passing the integer stride/padding to both axes of both convolutions
        # (as before) padded the "size 1" axis too, so the output grew by
        # 2 * padding per axis and was strided twice. Weight shapes are unchanged.
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, (kernel_length, 1), (stride, 1), (padding, 0), bias=False),
            nn.Conv2d(C_in, C_out, (1, kernel_length), (1, stride), (0, padding), bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        """Apply ReLU, the two 1-D convolutions and batch norm to ``x``."""
        return self.net(x)
class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise (stride=2).

    Two stride-2 1x1 convolutions sample the input on grids offset by one
    pixel; their outputs are concatenated along the channel dimension and
    batch-normalised.

    Parameters
    ----------
    C_in : int
        Number of input channels.
    C_out : int
        Number of output channels (split between the two branches).
    affine : bool
        Whether the batch norm has learnable affine parameters.
    """
    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU()
        # For odd C_out the second branch takes the extra channel so the
        # concatenation always has exactly C_out channels (``C_out // 2`` twice
        # would leave BatchNorm2d(C_out) one channel short). Even C_out is unchanged.
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out - C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        """Return the reduced feature map with ``C_out`` channels."""
        x = self.relu(x)
        # Shift by one pixel for the second branch. Zero-padding one row/column
        # at the bottom/right first keeps both branches the same size when H or
        # W is odd; for even sizes the padded row/column is never sampled, so
        # the result equals the plain ``x[:, :, 1:, 1:]`` slice.
        shifted = nn.functional.pad(x, (0, 1, 0, 1))[:, :, 1:, 1:]
        out = torch.cat([self.conv1(x), self.conv2(shifted)], dim=1)
        out = self.bn(out)
        return out
darts_select.py --best_selected_space_path 'experiment_id/best_selected_space.json' ` + +# retrain stage +`python darts_retrain.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --best_selected_space_path 'experiment_id/best_selected_space.json' --best_checkpoint_dir 'experiment_id/' --trial_id 0 --batch_size 96 --epochs 1 --lr 0.025 --layers 20 --channels 36` + +# output file +`result.json` +``` +{'type': 'Accuracy', 'result': {'sequence': 0, 'category': 'epoch', 'value': 0.1}} +{'type': 'Accuracy', 'result': {'sequence': 1, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 2, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 3, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 4, 'category': 'epoch', 'value': 0.0}} +{'type': 'Cost_time', 'result': {'value': '41.614346981048584 s'}} +``` + +`search_space.json` +``` +{ + "op_list": { + "_type": "layer_choice", + "_value": [ + "maxpool", + "avgpool", + "skipconnect", + "sepconv3x3", + "sepconv5x5", + "dilconv3x3", + "dilconv5x5", + "none" + ] + }, + "search_space": { + "normal_n2_p0": "op_list", + "normal_n2_p1": "op_list", + "normal_n2_switch": { + "_type": "input_choice", + "_value": { + "candidates": [ + "normal_n2_p0", + "normal_n2_p1" + ], + "n_chosen": 2 + } + }, + + ... + } +``` + +`best_selected_space.json` +``` +{ + "normal_n2_p0": "dilconv5x5", + "normal_n2_p1": "dilconv5x5", + "normal_n2_switch": [ + "normal_n2_p0", + "normal_n2_p1" + ], + "normal_n3_p0": "sepconv3x3", + "normal_n3_p1": "dilconv5x5", + "normal_n3_p2": [], + "normal_n3_switch": [ + "normal_n3_p0", + "normal_n3_p1" + ], + "normal_n4_p0": [], + "normal_n4_p1": "dilconv5x5", + "normal_n4_p2": "sepconv5x5", + "normal_n4_p3": [], + "normal_n4_switch": [ + "normal_n4_p1", + "normal_n4_p2" + ], + ... 
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for every ``k`` in ``topk``.

    Parameters
    ----------
    output : torch.Tensor
        Scores with one row per sample; the top-k search runs over dim 1.
    target : torch.Tensor
        Class indices, or a tensor with more than one dimension, in which
        case the argmax over dim 1 is taken as the label (one-hot case).
    topk : tuple of int
        The values of k to evaluate.

    Returns
    -------
    dict
        Maps ``"acc<k>"`` to the fraction of samples whose label is among
        the k highest-scoring classes.
    """
    n_samples = target.size(0)
    largest_k = max(topk)

    # One row per rank after the transpose: row r holds the rank-r prediction.
    ranked = output.topk(largest_k, 1, True, True)[1].t()

    # one-hot case
    labels = target.max(1)[1] if target.ndimension() > 1 else target

    hits = ranked.eq(labels.view(1, -1).expand_as(ranked))

    return {
        "acc{}".format(k): hits[:k].reshape(-1).float().sum(0).mul_(1.0 / n_samples).item()
        for k in topk
    }
= types.FLOAT, # output_dtype=types.FLOAT, + output_layout=types.NCHW, + mean=0. ,# if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], + std=1. )# if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255]) + + # #################################################### + + def define_graph(self, ): + jpegs, labels = self.input(name="Reader") + images = self.decode(jpegs) + images = self.res(images) + images = self.cmnp(images) + return images, labels + + +class HybridValPipeline(Pipeline): + + def __init__(self, batch_size, file_root, num_threads, device_id, num_shards, shard_id): + super(HybridValPipeline, self).__init__(batch_size, num_threads, device_id) + + device_type = {0:"cpu"} + if num_shards == 0: + self.input = ops.FileReader(file_root = file_root) + else: + self.input = ops.FileReader(file_root = file_root, num_shards = num_shards, shard_id = shard_id) + + + + # ##### 可自由更改 ################################### + self.decode = ops.ImageDecoder(device = device_type.get(num_shards, "mixed"), output_type = types.RGB) + self.res = ops.RandomResizedCrop(device=device_type.get(num_shards, "gpu"), size = 224) + self.cmnp = ops.CropMirrorNormalize(device=device_type.get(num_shards, "gpu"), + dtype = types.FLOAT, # output_dtype=types.FLOAT, + output_layout=types.NCHW, + mean=0. ,# if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], + std=1. 
class TorchWrapper:

    """
    Wrap one or several pipeline iterators as a single Python iterator.

    parameters:
        num_shards : int   number of parallel shards (GPUs); 0 means a single,
                           non-sharded iterator is passed as ``data_loader``
        data_loader : the iterator (``num_shards == 0``) or a list of
                      iterators, one per shard
        iter_mode : str   "recursion" or "iter"; how several iterators are merged.
                          "recursion" takes batches from the iterators in
                          round-robin order, "iter" drains them one after
                          another. Defaults to "recursion".

    Each item produced is a ``(data, label)`` pair where ``label`` has been
    flattened and cast to ``long``.
    """

    def __init__(self, num_shards, data_loader, iter_mode="recursion"):
        self.index = 0
        self.count = 0
        self.num_shards = num_shards
        self.data_loader = data_loader
        self.iter_mode = iter_mode
        if self.iter_mode not in {"recursion", "iter"}:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError("iter_mode should be either 'recursion' or 'iter'")

    def __iter__(self):
        return self

    def __len__(self):
        # Per the original author's note: the total number of samples, not batches.
        # (Previously read an undefined global ``num_shards`` -> NameError.)
        if self.num_shards == 0:
            return self.data_loader.size
        return len(self.data_loader) * self.data_loader[0].size

    def __next__(self):
        if self.num_shards == 0:
            # Single iterator (original comment: "no GPU used").
            return self._unpack(next(self.data_loader))

        # One or several per-shard iterators.
        if self.iter_mode == "recursion":
            return self._get_next_recursion()
        return self._get_next_iter()

    @staticmethod
    def _unpack(batch):
        """Extract ``(data, label)`` from one batch of the wrapped iterator."""
        return batch[0]["data"], batch[0]["label"].view(-1).long()

    def _get_next_iter(self, data_loader=None):
        """Return the next batch, draining the iterators one after another.

        ``data_loader`` is ignored and kept only for backward compatibility;
        the iterator to read from is tracked in ``self.index``. The previous
        version always restarted from the first iterator and compared a batch
        counter against a sample count, so it never moved on correctly.
        """
        while self.index < len(self.data_loader):
            try:
                batch = next(self.data_loader[self.index])
            except StopIteration:
                # Current iterator is exhausted; move on to the next one.
                self.index += 1
                continue
            self.count += 1
            return self._unpack(batch)

        # All iterators are exhausted: rewind the cursor and end the iteration.
        self.index = 0
        raise StopIteration

    def _get_next_recursion(self):
        """Return the next batch, cycling over the iterators in round-robin order."""
        self.index = self.count % self.num_shards
        self.count += 1

        batch = next(self.data_loader[self.index])
        return self._unpack(batch)
shard_id=[-1]): + + """ + 获取可用于pytorch训练的数据迭代器 + 数据的读取和处理部分可以使用多张GPU来完成 + + 1、创建dali pipeline + 2、封装为适用于pytorch的数据迭代器 + 3、将多卡的各个pipeline封装在一起 + 4、数据输出在cpu端,在cuda中 + + 数据需要保证如下形式: + images + |-file_list.txt + |-images/dog + |-dog_4.jpg + |-dog_5.jpg + |-dog_9.jpg + |-dog_6.jpg + |-dog_3.jpg + |-images/kitten + |-cat_10.jpg + |-cat_5.jpg + |-cat_9.jpg + |-cat_8.jpg + |-cat_1.jpg + + parameters: + + batch_size : int 每批数据的量 + file_root : str 数据的路径 + num_threads : int 读取数据的CPU线程数 + device_id : list of int GPU的物理编号 + shard_id : list of int GPU的虚拟编号 + num_shard : int + + methods: + + get_train_pipeline(shard_id, device_id) : 创建dali的pipeline,用以读取并处理训练数据 + get_val_pipeline(shard_id, device_id) : 创建dali的pipeline,用以读取并处理验证数据 + get_dali_iter_for_torch(piplines, data_num) : 封装成可用于pytorch的数据迭代器 + get_data_size(pipeline) : 计算每个pipeline实际输出的数据总量,数据总量是文件中的数据量,实际输出是去掉了不满一个批次大小的数据 + + 例: + # 分别从TRAIN_PATH和VAL_PATH读取训练和验证数据,batch_size选择256,启动4个线程来读取数据,用2块GPU处理数据,分别是第0号和第4号GPU + # 程序默认使用所有显卡,和4线程 + # 如果使用单张GPU,请设置num_shards = 1, shard_id = [0], device_id保持一个列表形式 + # 如果不使用GPU,请使用get_iter_dali_cpu() + train_data_iter, val_data_iter = get_iter_dali(batch_size=256, + train_file_root=TRAIN_PATH, + val_file_root=Val_PATH, + num_threads=4, + device_id=[0,4], + num_shards=2, + shard_id=[0,1]) + + # 在torch中训练 + torch_model = TorchModel(para) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(torch_model.parameters()) + + for epoch in range(epoches): + for step, x,y in enumerate(train_data_iter): + + # 数据 : x + # 标签 : y + x = x.to("cuda:0") + y = y.to("cuda:0") + output = my_model(x) + + optimizer.zero_grad() + loss = criterion(output, y) + loss.backward() + optimizer.step() + ... + ... 
def get_iter_dali_cpu(batch_size=256, train_file_root="", val_file_root="", num_threads=4):
    """Build CPU-only training and validation data iterators.

    Both pipelines are created with ``num_shards=0``, which the pipeline
    classes in this file map to the ``"cpu"`` device for every operator.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    train_file_root : str
        Root directory of the training images.
    val_file_root : str
        Root directory of the validation images.
    num_threads : int
        Number of CPU threads used by each pipeline.

    Returns
    -------
    tuple
        ``(train_data_iter, val_data_iter)``, both ``TorchWrapper`` instances.
    """

    def build_pipeline(pipeline_cls, file_root):
        # num_shards=0 / shard_id=-1 select the non-sharded, CPU configuration.
        pipeline = pipeline_cls(batch_size=batch_size,
                                file_root=file_root,
                                num_threads=num_threads,
                                num_shards=0,
                                shard_id=-1,
                                device_id=0)
        pipeline.build()
        return pipeline

    def get_data_size(pipeline):
        # Number of samples actually produced: the trailing partial batch is dropped.
        data_num = pipeline.epoch_size()["Reader"]
        batch_size = pipeline.batch_size
        return data_num // batch_size * batch_size

    def to_torch_iter(pipeline):
        loader = DALIClassificationIterator(pipelines=pipeline,
                                            last_batch_policy="drop",
                                            size=get_data_size(pipeline))
        return TorchWrapper(0, loader)

    pipeline_train = build_pipeline(HybridTrainPipeline, train_file_root)
    # The validation data previously went through HybridTrainPipeline as well;
    # use the validation pipeline, as get_iter_dali_cuda does.
    pipeline_val = build_pipeline(HybridValPipeline, val_file_root)

    train_data_iter = to_torch_iter(pipeline_train)
    val_data_iter = to_torch_iter(pipeline_val)

    return train_data_iter, val_data_iter
Requirements +``` +torch +torchvision +collections +argparser +pickle +pytest-shutil +``` + +## 2.Train +### Stage1: search an architecture + +* macro search + +``` +python trainer.py --trial_id=0 --search_for macro --best_selected_space_path='./macro_selected_space.json' --result_path='./macro_result.json' +``` + +* micro search + +``` +python trainer.py --trial_id=0 --search_for micro --best_selected_space_path='./micro_selected_space.json' --result_path='./micro_result.json' +``` + +### Stage2: select (deprecated) +``` +python selector.py +``` + +### stage3: retrain +* macro search + +``` +python retrainer.py --search_for macro --best_checkpoint_dir='./macro_checkpoint.pth' --best_selected_space_path= +'./macro_selected_space.json' --result_path='./macro_result.json' +``` + +* micro search + +``` +python retrainer.py --search_for micro --best_checkpoint_dir='./micro_checkpoint.pth' --best_selected_space_path= +'./micro_selected_space.json' --result_path='./micro_result.json' +``` \ No newline at end of file diff --git a/dubhe-tadl/enas/__init__.py b/dubhe-tadl/enas/__init__.py new file mode 100644 index 0000000..d337283 --- /dev/null +++ b/dubhe-tadl/enas/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .mutator import EnasMutator +from .trainer import EnasTrainer diff --git a/dubhe-tadl/enas/datasets.py b/dubhe-tadl/enas/datasets.py new file mode 100644 index 0000000..fd86df2 --- /dev/null +++ b/dubhe-tadl/enas/datasets.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
def get_dataset(cls, datadir):
    """Return the ``(train, valid)`` datasets for the dataset named ``cls``.

    Parameters
    ----------
    cls : str
        Dataset name; only ``"cifar10"`` is implemented.
    datadir : str
        Directory passed to ``CIFAR10`` as ``root`` (downloaded if missing).

    Raises
    ------
    NotImplementedError
        If ``cls`` is not ``"cifar10"``.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    # Augmentation is applied to the training split only.
    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    train_transform = transforms.Compose(augment + to_normalized_tensor)
    valid_transform = transforms.Compose(to_normalized_tensor)

    if cls != "cifar10":
        raise NotImplementedError

    dataset_train = CIFAR10(root=datadir, train=True, download=True, transform=train_transform)
    dataset_valid = CIFAR10(root=datadir, train=False, download=True, transform=valid_transform)
    return dataset_train, dataset_valid
class ENASLayer(mutables.MutableScope):
    """One searchable layer of the macro search space.

    Parameters
    ----------
    key : str
        Key passed to ``MutableScope`` for this layer.
    prev_labels : list of str
        Keys of the earlier layers that may be chosen as skip connections.
    in_filters : int
        Number of input channels.
    out_filters : int
        Number of output channels.
    """

    def __init__(self, key, prev_labels, in_filters, out_filters):
        super().__init__(key)
        self.in_filters = in_filters
        self.out_filters = out_filters
        # Six candidate operations: 3x3 / 5x5 convolutions (plain and separable)
        # and 3x3 average / max pooling.
        self.mutable = mutables.LayerChoice([
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
            PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
            PoolBranch('max', in_filters, out_filters, 3, 1, 1)
        ])
        # The first layer has no predecessors, hence no skip-connection choice.
        if len(prev_labels) > 0:
            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
        else:
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)

    def forward(self, prev_layers):
        """Apply the chosen op to the last entry of ``prev_layers`` and add skip inputs.

        ``prev_layers[:-1]`` are offered to the skip-connection choice; its
        result, when not ``None``, is added in place to the op output before
        batch normalisation.
        """
        out = self.mutable(prev_layers[-1])
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[:-1])
            if connection is not None:
                out += connection
        return self.batch_norm(out)
labels.append("layer_{}".format(layer_id)) + if layer_id in self.pool_layers_idx: + self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) + self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters)) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.dense = nn.Linear(self.out_filters, self.num_classes) + + def forward(self, x): + bs = x.size(0) + cur = self.stem(x) + + layers = [cur] + + for layer_id in range(self.num_layers): + cur = self.layers[layer_id](layers) + layers.append(cur) + if layer_id in self.pool_layers_idx: + for i, layer in enumerate(layers): + layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) + cur = layers[-1] + + cur = self.gap(cur).view(bs, -1) + cur = self.dropout(cur) + logits = self.dense(cur) + return logits diff --git a/dubhe-tadl/enas/micro.py b/dubhe-tadl/enas/micro.py new file mode 100644 index 0000000..42e2fdb --- /dev/null +++ b/dubhe-tadl/enas/micro.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
class AuxiliaryHead(nn.Module):
    """Auxiliary classifier: ReLU, average pooling, two ``StdConv`` projections, linear layer.

    Parameters
    ----------
    in_channels : int
        Number of channels of the incoming feature map.
    num_classes : int
        Number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.pooling = nn.Sequential(
            nn.ReLU(),
            nn.AvgPool2d(5, 3, 2)
        )
        self.proj = nn.Sequential(
            StdConv(in_channels, 128),
            StdConv(128, 768)
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # The output size was hard-coded to 10, silently ignoring ``num_classes``.
        # Networks built with num_classes=10 keep the same layer shape.
        self.fc = nn.Linear(768, num_classes, bias=False)

    def forward(self, x):
        """Return the auxiliary logits, one row per sample of ``x``."""
        bs = x.size(0)
        x = self.pooling(x)
        x = self.proj(x)
        # Global average pooling, then flatten to (batch, 768).
        x = self.avg_pool(x).view(bs, -1)
        x = self.fc(x)
        return x
class ReductionLayer(nn.Module):
    """Apply a separate ``FactorizedReduce`` to each of the two incoming feature maps.

    Parameters
    ----------
    in_channels_pp : int
        Channel count of the first input (``pprev``).
    in_channels_p : int
        Channel count of the second input (``prev``).
    out_channels : int
        Channel count produced for both outputs.
    """

    def __init__(self, in_channels_pp, in_channels_p, out_channels):
        super().__init__()
        self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
        self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)

    def forward(self, pprev, prev):
        """Return the reduced ``(pprev, prev)`` pair, in that order."""
        return self.reduce0(pprev), self.reduce1(prev)
class MicroNetwork(nn.Module):
    """Network assembled from stacked ``ENASLayer`` cells with two reduction stages.

    Parameters
    ----------
    num_layers : int
        Controls the depth; ``num_layers + 2`` layer positions are built.
    num_nodes : int
        Number of nodes passed to every ``ENASLayer``.
    out_channels : int
        Channel count of the first cells; the stem produces three times as many.
    in_channels : int
        Number of channels of the input images.
    num_classes : int
        Number of output logits.
    dropout_rate : float
        Dropout probability applied before the final linear layer.
    use_aux_heads : bool
        Whether to insert an ``AuxiliaryHead`` after the last reduction stage.
    """

    def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10,
                 dropout_rate=0.0, use_aux_heads=False):
        super().__init__()
        self.num_layers = num_layers
        self.use_aux_heads = use_aux_heads

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels * 3)
        )

        # Layer positions at which a ReductionLayer is inserted before the cell.
        pool_distance = self.num_layers // 3
        pool_layers = [pool_distance, 2 * pool_distance + 1]
        self.dropout = nn.Dropout(dropout_rate)

        self.layers = nn.ModuleList()
        # c_pp / c_p: channel counts of the two inputs of the next layer;
        # c_cur: channel count produced by the next layer.
        c_pp = c_p = out_channels * 3
        c_cur = out_channels
        # NOTE(review): the nesting of this loop body was reconstructed from
        # whitespace-collapsed text - confirm against the upstream file.
        for layer_id in range(self.num_layers + 2):
            reduction = False
            if layer_id in pool_layers:
                # Double the channels and reduce both inputs before the cell.
                c_cur, reduction = c_p * 2, True
                self.layers.append(ReductionLayer(c_pp, c_p, c_cur))
                c_pp = c_p = c_cur
            self.layers.append(ENASLayer(num_nodes, c_pp, c_p, c_cur, reduction))
            if self.use_aux_heads and layer_id == pool_layers[-1] + 1:
                self.layers.append(AuxiliaryHead(c_cur, num_classes))
            c_pp, c_p = c_p, c_cur

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(c_cur, num_classes)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weight of every ``nn.Conv2d`` with Kaiming-normal values."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Return the logits, or ``(logits, aux_logits)`` when an auxiliary head ran.

        The auxiliary head is evaluated only in training mode.
        """
        bs = x.size(0)
        prev = cur = self.stem(x)
        aux_logits = None

        for layer in self.layers:
            if isinstance(layer, AuxiliaryHead):
                # The head only reads ``cur``; it does not alter the main path.
                if self.training:
                    aux_logits = layer(cur)
            else:
                prev, cur = layer(prev, cur)

        cur = self.gap(F.relu(cur)).view(bs, -1)
        cur = self.dropout(cur)
        logits = self.dense(cur)

        if aux_logits is not None:
            return logits, aux_logits
        return logits
class StackedLSTMCell(nn.Module):
    """A stack of ``nn.LSTMCell`` modules applied one after another.

    Parameters
    ----------
    layers : int
        Number of stacked cells.
    size : int
        Input size and hidden size of every cell.
    bias : bool
        Whether the cells use bias terms.
    """

    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        cells = [nn.LSTMCell(size, size, bias=bias) for _ in range(self.lstm_num_layers)]
        self.lstm_modules = nn.ModuleList(cells)

    def forward(self, inputs, hidden):
        """Advance every cell by one step.

        ``hidden`` is a pair of per-layer lists ``(prev_c, prev_h)``; the
        result is the pair of updated lists in the same order.

        NOTE(review): ``nn.LSTMCell`` is documented to take and return
        ``(h, c)``; here the first element is called ``c`` and the second
        ``h`` throughout. The labelling is consistent between input and
        output, but confirm it is intended.
        """
        prev_c, prev_h = hidden
        step_states = []
        layer_input = inputs
        for idx, cell in enumerate(self.lstm_modules):
            first, second = cell(layer_input, (prev_c[idx], prev_h[idx]))
            step_states.append((first, second))
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            layer_input = second[-1].view(1, -1)
        next_c = [state[0] for state in step_states]
        next_h = [state[1] for state in step_states]
        return next_c, next_h
+ If a mutable has a ``reduce`` in its key, all its op choices + that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others + receive a bias of ``-self.branch_bias``. + entropy_reduction : str + Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced. + """ + + def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, + skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"): + super().__init__(model) + self.lstm_size = lstm_size + self.lstm_num_layers = lstm_num_layers + self.tanh_constant = tanh_constant + self.temperature = temperature + self.cell_exit_extra_step = cell_exit_extra_step + self.skip_target = skip_target + self.branch_bias = branch_bias + + self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) + self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) + self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) + self.v_attn = nn.Linear(self.lstm_size, 1, bias=False) + self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1) + self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable + assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean." + self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean + self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") + self.bias_dict = nn.ParameterDict() + + self.max_layer_choice = 0 + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + if self.max_layer_choice == 0: + self.max_layer_choice = len(mutable) + assert self.max_layer_choice == len(mutable), \ + "ENAS mutator requires all layer choice have the same number of candidates." + # We are judging by keys and module types to add biases to layer choices. Needs refactor. 
+ if "reduce" in mutable.key: + def is_conv(choice): + return "conv" in str(type(choice)).lower() + bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable + for choice in mutable]) + self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False) + + self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) + self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False) + + def sample_search(self): + self._initialize() + self._sample(self.mutables) + return self._choices + + def sample_final(self): + return self.sample_search() + + def _sample(self, tree): + mutable = tree.mutable + if isinstance(mutable, LayerChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_layer_choice(mutable) + elif isinstance(mutable, InputChoice) and mutable.key not in self._choices: + self._choices[mutable.key] = self._sample_input_choice(mutable) + for child in tree.children: + self._sample(child) + if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid: + if self.cell_exit_extra_step: + self._lstm_next_step() + self._mark_anchor(mutable.key) + + def _initialize(self): + self._choices = dict() + self._anchors_hid = dict() + self._inputs = self.g_emb.data + self._c = [torch.zeros((1, self.lstm_size), + dtype=self._inputs.dtype, + device=self._inputs.device) for _ in range(self.lstm_num_layers)] + self._h = [torch.zeros((1, self.lstm_size), + dtype=self._inputs.dtype, + device=self._inputs.device) for _ in range(self.lstm_num_layers)] + self.sample_log_prob = 0 + self.sample_entropy = 0 + self.sample_skip_penalty = 0 + + def _lstm_next_step(self): + self._c, self._h = self.lstm(self._inputs, (self._c, self._h)) + + def _mark_anchor(self, key): + self._anchors_hid[key] = self._h[-1] + + def _sample_layer_choice(self, mutable): + self._lstm_next_step() + logit = self.soft(self._h[-1]) + if self.temperature is not None: + logit /= 
self.temperature + if self.tanh_constant is not None: + logit = self.tanh_constant * torch.tanh(logit) + if mutable.key in self.bias_dict: + logit += self.bias_dict[mutable.key] + branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + log_prob = self.cross_entropy_loss(logit, branch_id) + self.sample_log_prob += self.entropy_reduction(log_prob) + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type + self.sample_entropy += self.entropy_reduction(entropy) + self._inputs = self.embedding(branch_id) + return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) + + def _sample_input_choice(self, mutable): + query, anchors = [], [] + for label in mutable.choose_from: + if label not in self._anchors_hid: + self._lstm_next_step() + self._mark_anchor(label) # empty loop, fill not found + query.append(self.attn_anchor(self._anchors_hid[label])) + anchors.append(self._anchors_hid[label]) + query = torch.cat(query, 0) + query = torch.tanh(query + self.attn_query(self._h[-1])) + query = self.v_attn(query) + if self.temperature is not None: + query /= self.temperature + if self.tanh_constant is not None: + query = self.tanh_constant * torch.tanh(query) + + if mutable.n_chosen is None: + logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type + + skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + skip_prob = torch.sigmoid(logit) + kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets)) + self.sample_skip_penalty += kl + log_prob = self.cross_entropy_loss(logit, skip) + self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0) + else: + assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS." 
+ logit = query.view(1, -1) + index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) + skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1) + log_prob = self.cross_entropy_loss(logit, index) + self._inputs = anchors[index.item()] + + self.sample_log_prob += self.entropy_reduction(log_prob) + entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type + self.sample_entropy += self.entropy_reduction(entropy) + return skip.bool() diff --git a/dubhe-tadl/enas/ops.py b/dubhe-tadl/enas/ops.py new file mode 100644 index 0000000..59d615a --- /dev/null +++ b/dubhe-tadl/enas/ops.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import torch.nn as nn + + +class StdConv(nn.Module): + def __init__(self, C_in, C_out): + super(StdConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=False), + nn.ReLU() + ) + + def forward(self, x): + return self.conv(x) + + def __str__(self): + return 'StdConv' + + +class PoolBranch(nn.Module): + def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False): + super().__init__() + self.kernel_size = kernel_size + self.pool_type = pool_type + self.preproc = StdConv(C_in, C_out) + self.pool = Pool(pool_type, kernel_size, stride, padding) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + out = self.preproc(x) + out = self.pool(out) + out = self.bn(out) + return out + + def __str__(self): + return '{}PoolBranch_{}'.format(self.pool_type, self.kernel_size) + +class SeparableConv(nn.Module): + def __init__(self, C_in, C_out, kernel_size, stride, padding): + self.kernel_size = kernel_size + super(SeparableConv, self).__init__() + self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride, + groups=C_in, bias=False) + self.pointwise = nn.Conv2d(C_in, C_out, 
class ConvBranch(nn.Module):
    """
    Convolution branch: ``StdConv`` pre-processing, one ``kernel_size`` convolution, then BN + ReLU.

    Parameters
    ----------
    C_in : int
        Number of input channels.
    C_out : int
        Number of channels after pre-processing; kept by the rest of the branch.
    kernel_size : int
        Kernel size of the main convolution.
    stride : int
        Stride of the main convolution.
    padding : int
        Padding of the main convolution.
    separable : bool
        Use ``SeparableConv`` when true, a plain ``nn.Conv2d`` otherwise.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
        super(ConvBranch, self).__init__()
        self.kernel_size = kernel_size
        self.preproc = StdConv(C_in, C_out)
        if separable:
            self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
        else:
            self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
        self.postproc = nn.Sequential(
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        out = self.preproc(x)
        out = self.conv(out)
        out = self.postproc(out)
        return out

    def __str__(self):
        return 'ConvBranch_{}'.format(self.kernel_size)


class FactorizedReduce(nn.Module):
    """
    Reduce spatial size with two stride-2 1x1 convolutions whose outputs are concatenated.

    The second convolution reads the input shifted by one pixel in both spatial dimensions. Each
    convolution emits ``C_out // 2`` channels, so ``C_out`` has to be even for the concatenation to
    match ``BatchNorm2d(C_out)``.

    Parameters
    ----------
    C_in : int
        Number of input channels.
    C_out : int
        Number of output channels.
    affine : bool
        Whether the final batch norm has learnable affine parameters.
    """

    def __init__(self, C_in, C_out, affine=False):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out

    def __str__(self):
        return 'FactorizedReduce'


class Pool(nn.Module):
    """
    Max or average 2-D pooling selected by name.

    Parameters
    ----------
    pool_type : str
        ``'max'`` or ``'avg'`` (case-insensitive).
    kernel_size : int
        Pooling window size.
    stride : int
        Pooling stride.
    padding : int
        Implicit padding; excluded from the average for ``'avg'``.

    Raises
    ------
    ValueError
        If ``pool_type`` is neither ``'max'`` nor ``'avg'``.
    """

    def __init__(self, pool_type, kernel_size, stride, padding):
        super().__init__()
        self.kernel_size = kernel_size
        self.pool_type = pool_type
        kind = pool_type.lower()
        if kind == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif kind == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            # Name the offending value so a typo in a search space file is easy to track down.
            raise ValueError("Unsupported pool_type {!r}; expected 'max' or 'avg'.".format(pool_type))

    def forward(self, x):
        return self.pool(x)

    def __str__(self):
        return '{}Pool_{}'.format(self.pool_type, self.kernel_size)
def set_random_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators with ``seed``.

    When ``FLAGS.is_cuda`` is set, the CUDA generators are seeded as well and cuDNN is switched to
    deterministic mode.

    Parameters
    ----------
    seed : int
        Seed value shared by all generators.
    """
    logger.info("set random seed for data reading: {}".format(seed))
    # Exported for child processes; the hash seed of the running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # torch has no ``manual_seed_all``; ``torch.manual_seed`` seeds the default generator.
    torch.manual_seed(seed)
    if FLAGS.is_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
(default: %(default)s)") + parser.add_argument("--search_space_path", type=str, + default='./search_space.json', help="search_space directory") + parser.add_argument( + "--selected_space_path", + type=str, + default="./selected_space.json", + # required=True, + help="Architecture json file. (default: %(default)s)") + parser.add_argument("--result_path", type=str, + default='./result.json', help="res directory") + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + parser.add_argument( + "--output_dir", + type=str, + default="./output", + help="The output directory. (default: %(default)s)") + parser.add_argument( + "--best_checkpoint_dir", + type=str, + default="best_checkpoint", + help="Path for saved checkpoints. (default: %(default)s)") + parser.add_argument("--search_for", + choices=["macro", "micro"], + default="micro") + parser.add_argument( + "--batch_size", + type=int, + default=128, + help="Number of samples each batch for training. (default: %(default)s)") + parser.add_argument( + "--eval_batch_size", + type=int, + default=128, + help="Number of samples each batch for evaluation. (default: %(default)s)") + parser.add_argument( + "--class_num", + type=int, + default=10, + help="The number of categories. (default: %(default)s)") + parser.add_argument( + "--epochs", + type=int, + default=10, + help="The number of training epochs. (default: %(default)s)") + parser.add_argument( + "--child_lr", + type=float, + default=0.02, + help="The initial learning rate. (default: %(default)s)") + parser.add_argument( + "--is_cuda", + type=distutils.util.strtobool, + default=True, + help="Specify the device type. (default: %(default)s)") + parser.add_argument( + "--load_checkpoint", + type=distutils.util.strtobool, + default=False, + help="Whether to load checkpoint. (default: %(default)s)") + parser.add_argument( + "--log_every", + type=int, + default=50, + help="How many steps to log. 
(default: %(default)s)") + parser.add_argument( + "--eval_every_epochs", + type=int, + default=1, + help="How many epochs to eval. (default: %(default)s)") + parser.add_argument( + "--child_grad_bound", + type=float, + default=5.0, + help="The threshold for gradient clipping. (default: %(default)s)") # + parser.add_argument( + "--child_lr_decay_scheme", + type=str, + default="cosine", + help="Learning rate annealing strategy, only 'cosine' supported. (default: %(default)s)") #todo: remove + parser.add_argument( + "--child_lr_T_0", + type=int, + default=10, + help="The length of one cycle. (default: %(default)s)") # todo: use for + parser.add_argument( + "--child_lr_T_mul", + type=int, + default=2, + help="The multiplication factor per cycle. (default: %(default)s)") # todo: use for + parser.add_argument( + "--child_l2_reg", + type=float, + default=3e-6, + help="Weight decay factor. (default: %(default)s)") + parser.add_argument( + "--child_lr_max", + type=float, + default=0.002, + help="The max learning rate. (default: %(default)s)") + parser.add_argument( + "--child_lr_min", + type=float, + default=0.001, + help="The min learning rate. (default: %(default)s)") + parser.add_argument( + "--multi_path", + type=distutils.util.strtobool, + default=False, + help="Search for multiple path in the architecture. (default: %(default)s)") # todo: use for + parser.add_argument( + "--is_mask", + type=distutils.util.strtobool, + default=True, + help="Apply mask. (default: %(default)s)") + global FLAGS + FLAGS = parser.parse_args() + + +def print_user_flags(FLAGS, line_limit=80): + log_strings = "\n" + "-" * line_limit + "\n" + for flag_name in sorted(vars(FLAGS)): + value = "{}".format(getattr(FLAGS, flag_name)) + log_string = flag_name + log_string += "." 
def update_lr(
        optimizer,
        epoch,
        l2_reg=1e-4,
        lr_warmup_val=None,
        lr_init=0.1,
        lr_decay_scheme="cosine",
        lr_max=0.002,
        lr_min=0.000000001,
        lr_T_0=4,
        lr_T_mul=1,
        sync_replicas=False,
        num_aggregate=None,
        num_replicas=None):
    """
    Compute the cosine-with-restarts learning rate for ``epoch`` and write it into ``optimizer``.

    The first cycle lasts ``lr_T_0`` epochs and every following cycle is ``lr_T_mul`` times longer.
    Within a cycle the rate falls from ``lr_max`` towards ``lr_min`` along half a cosine period.

    Parameters
    ----------
    optimizer : object
        Anything exposing ``param_groups``, an iterable of dicts; the ``'lr'`` entry of each is set.
    epoch : int
        Epoch index, counted from 0.
    lr_decay_scheme : str
        Only ``"cosine"`` is supported.
    lr_max, lr_min : float
        Upper and lower bound of the schedule.
    lr_T_0 : int
        Length of the first cycle; must be positive.
    lr_T_mul : int
        Growth factor of the cycle length; must be at least 1.
    l2_reg, lr_warmup_val, lr_init, sync_replicas, num_aggregate, num_replicas
        Accepted to keep the signature stable; not used by the computation.

    Returns
    -------
    float
        The learning rate that was written into every parameter group.

    Raises
    ------
    ValueError
        For an unknown ``lr_decay_scheme``, a non-positive ``lr_T_0`` or an ``lr_T_mul`` below 1.
    """
    if lr_decay_scheme != "cosine":
        raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme))

    assert lr_max is not None, "Need lr_max to use lr_cosine"
    assert lr_min is not None, "Need lr_min to use lr_cosine"
    assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
    assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"
    # A cycle that never grows past the current epoch would keep the search below from terminating.
    if lr_T_0 <= 0 or lr_T_mul < 1:
        raise ValueError(
            "lr_T_0 must be positive and lr_T_mul at least 1, got {} and {}".format(lr_T_0, lr_T_mul))

    # Walk forward cycle by cycle until the one containing ``epoch`` is found.
    cycle_len = lr_T_0
    cycle_start = 0
    while epoch - cycle_start >= cycle_len:
        cycle_start += cycle_len
        cycle_len *= lr_T_mul

    phase = (epoch - cycle_start) / cycle_len * math.pi
    learning_rate = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(phase))

    # update lr in optimizer
    for params_group in optimizer.param_groups:
        params_group['lr'] = learning_rate
    return learning_rate
weight_decay=FLAGS.child_l2_reg) + # TODO + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.001) + + # move model to CPU/GPU device + child_model.to(device) + criterion.to(device) + + logger.info('Start training') + start_time = time.time() + step = 0 + + # save path + if not os.path.exists(output_dir): + os.mkdir(output_dir) + # model_save_path = os.path.join(output_dir, "model.pth") + # best_model_save_path = os.path.join(output_dir, "best_model.pth") + best_acc = 0 + start_epoch = 0 + + # TODO: load checkpoints + + # train + for epoch in range(start_epoch, epochs): + lr = update_lr(optimizer, + epoch, + l2_reg= 1e-4, + lr_warmup_val=None, + lr_init=FLAGS.child_lr, + lr_decay_scheme=FLAGS.child_lr_decay_scheme, + lr_max=0.05, + lr_min=0.001, + lr_T_0=10, + lr_T_mul=2) + child_model.train() + for batch in train_dataloader: + step += 1 + + x, y = batch + x, y = to_device(x, device), to_device(y, device) + logits = child_model(x) + + if isinstance(logits, tuple): + logits, aux_logits = logits + aux_loss = criterion(aux_logits, y) + else: + aux_loss = 0. 
+ + acc = accuracy(logits, y) + loss = criterion(logits, y) + loss = loss + aux_weight * aux_loss + + optimizer.zero_grad() + loss.backward() + grad_norm = 0 + trainable_params = child_model.parameters() + + for param in trainable_params: + nn.utils.clip_grad_norm_(param, child_grad_bound) # clip grad + + optimizer.step() + + if step % log_every == 0: + curr_time = time.time() + log_string = "" + log_string += "epoch={:<6d}".format(epoch) + log_string += "ch_step={:<6d}".format(step) + log_string += " loss={:<8.6f}".format(loss) + log_string += " lr={:<8.4f}".format(lr) + log_string += " |g|={:<8.4f}".format(grad_norm) + log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0]) + log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60) + logger.info(log_string) + + epoch += 1 + save_state = { + 'step': step, + 'epoch': epoch, + 'child_model_state_dict': child_model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict()} + # print(' Epoch {:<3d} loss: {:<.2f} '.format(epoch, loss)) + # torch.save(save_state, model_save_path) + child_model.eval() + logger.info("Epoch {}: Eval".format(epoch)) + eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader) + logger.info( + "ch_step={} {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format(step, "test", eval_acc, "test", eval_loss)) + if eval_acc > best_acc: + best_acc = eval_acc + logger.info("Save best model") + # save_state = { + # 'step': step, + # 'epoch': epoch, + # 'child_model_state_dict': child_model.state_dict(), + # 'optimizer_state_dict': optimizer.state_dict()} + # torch.save(save_state, best_model_save_path) + save_best_checkpoint(checkpoint_dir, child_model, optimizer, epoch) + + result['accuracy'].append('Epoch {} acc: {:<6.4f}'.format(epoch, eval_acc,)) + + acc_l.append(eval_acc) + + print(result['accuracy'][-1]) + + print('max acc %.4f at epoch: %i'%(max(acc_l), np.argmax(np.array(acc_l)))) + print('Time cost: %.4f hours'%( 
# macro = True
parse_args()
child_fixed_arc = FLAGS.selected_space_path  # './macro_seletced_space'
search_for = FLAGS.search_for
# Seed every random number generator with the trial id
torch.manual_seed(FLAGS.trial_id)
torch.cuda.manual_seed_all(FLAGS.trial_id)
np.random.seed(FLAGS.trial_id)
random.seed(FLAGS.trial_id)

aux_weight = 0.4
result = {'accuracy': []}
acc_l = []


# decode human readable search space to model
def convert_selected_space_format():
    """
    Load the selected architecture from ``child_fixed_arc`` and convert it for the model.

    The JSON file holds an ``op_list`` entry plus one list per mutable. Lists with exactly one
    element are unwrapped: an element found in ``op_list`` becomes its index there, anything else is
    kept as the element itself. Lists of any other length are passed through unchanged. For macro
    search only the part of each key after the last ``_`` is kept.

    Returns
    -------
    dict
        Mapping from (possibly shortened) key to the converted value.
    """
    with open(child_fixed_arc) as js:
        selected_space = json.load(js)

    ops = selected_space.pop('op_list')
    shorten_keys = FLAGS.search_for == 'macro'
    new_selected_space = {}

    for key, value in selected_space.items():
        # Binding the key unconditionally keeps it defined whatever ``search_for`` holds.
        new_key = key.split('_')[-1] if shorten_keys else key

        if len(value) != 1:
            new_value = value
        elif value[0] in ops:
            new_value = ops.index(value[0])
        else:
            new_value = value[0]
        new_selected_space[new_key] = new_value
    return new_selected_space


fixed_arc = convert_selected_space_format()
# TODO : macro search or micro search
if FLAGS.search_for == 'macro':
    child_model = GeneralNetwork()
elif FLAGS.search_for == 'micro':
    child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)

apply_fixed_architecture(child_model, fixed_arc)


def dump_global_result(res_path, global_result, sort_keys=False):
    """Write ``global_result`` to ``res_path`` as JSON indented by two spaces."""
    with open(res_path, "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2)


def main():
    """Train the fixed architecture and dump the per-epoch accuracies to ``result_retrain.json``."""
    # Keep GPU 4 as the fallback but let a device selection made by the caller's environment win;
    # overwriting it unconditionally breaks every host that does not have a fifth GPU.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '4')
    device = torch.device("cuda" if FLAGS.is_cuda else "cpu")
    train(device)
    dump_global_result('result_retrain.json', result['accuracy'])


if __name__ == "__main__":
    main()
a/dubhe-tadl/enas/retrainer.py b/dubhe-tadl/enas/retrainer.py new file mode 100644 index 0000000..7c5ec36 --- /dev/null +++ b/dubhe-tadl/enas/retrainer.py @@ -0,0 +1,495 @@ +import sys +from utils import accuracy +import torch +from torch import nn +from torch.utils.data import DataLoader +import datasets +import time +import logging +import os +import argparse +import distutils.util +import numpy as np +import json +import random + +sys.path.append('..'+ '/' + '..') + +# import custom packages +from macro import GeneralNetwork +from micro import MicroNetwork +from pytorch.fixed import apply_fixed_architecture +from pytorch.retrainer import Retrainer +from pytorch.utils import AverageMeterGroup, to_device, save_best_checkpoint, mkdirs + +class EnasRetrainer(Retrainer): + """ + ENAS retrainer. + + Parameters + ---------- + model : nn.Module + PyTorch model to be trained. + data_dir : dataset path + The path of the dataset. + best_checkpoint_dir: 'best_checkpoint.pth' + The directory for saving model. + batch_size : int + Batch size. + eval_batch_size : int + Batch size. + num_epochs : int + Number of epochs planned for training. + lr : float + Learning rate. + is_cuda: Boolean + Whether to use GPU for training. + log_every : int + Step count per logging. + child_grad_bound : float + Gradient bound. + child_l2_reg: float + L2 regression. + eval_every_epochs: int + Evaluate every epochs. + logger: + logging. + workers : int + Workers for data loading. + device : torch.device + ``torch.device("cpu")`` or ``torch.device("cuda")``. + aux_weight : float + Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss. 
+ """ + def __init__(self,model,data_dir = './data',best_checkpoint_dir = './best_checkpoint', + batch_size = 1024, eval_batch_size = 1024,num_epochs = 2,lr = 0.02,is_cuda = 'True', + log_every = 40,child_grad_bound = 0.5, child_l2_reg=3e-6, eval_every_epochs=2, + logger = logging.getLogger("enas-retrain"), result_path='./'): + self.aux_weight = 0.4 + self.device = torch.device("cuda:0" ) + self.workers = 4 + + self.child_model = model + self.data_dir = data_dir + self.best_checkpoint_dir = best_checkpoint_dir + self.batch_size = batch_size + self.eval_batch_size = eval_batch_size + self.num_epochs = num_epochs + self.lr = lr + self.is_cuda = is_cuda + self.log_every = log_every + self.child_grad_bound = child_grad_bound + self.child_l2_reg = child_l2_reg + self.eval_every_epochs = eval_every_epochs + self.logger = logger + + self.optimizer = torch.optim.SGD(self.child_model.parameters(), self.lr, momentum=0.9, weight_decay=1.0E-4, nesterov=True) + self.criterion = nn.CrossEntropyLoss() + self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.num_epochs, eta_min=0.001) + + # load dataset + self.init_dataloader() + self.child_model.to(self.device) + self.result_path = result_path + with open(self.result_path, "w") as file: + file.write('') + + def train(self): + """ + Train ``num_epochs``. + Trigger callbacks at the start and the end of each epoch. + + Parameters + ---------- + validate : bool + If ``true``, will do validation every epoch. 
+ """ + self.logger.info('** Start training **') + + self.start_time = time.time() + for epoch in range(self.num_epochs): + + self.train_one_epoch(epoch) + + self.child_model.eval() + + # if epoch / self.eval_every_epochs == 0: + self.logger.info("Epoch {}: Eval".format(epoch)) + self.validate_one_epoch(epoch) + + self.lr_scheduler.step() + + # print('** saving model **') + + self.logger.info("** Save best model **") + # save_state = { + # 'epoch': epoch, + # 'child_model_state_dict': self.child_model.state_dict(), + # 'optimizer_state_dict': self.optimizer.state_dict()} + # torch.save(save_state, self.best_checkpoint_dir) + save_best_checkpoint(self.best_checkpoint_dir, self.child_model, self.optimizer, epoch) + + def validate(self): + """ + Do one validation. Validate one epoch. + """ + pass + + def export(self, file): + """ + dump the architecture to ``file``. + + Parameters + ---------- + file : str + File path to export to. Expected to be a JSON. + """ + pass + + def checkpoint(self): + """ + Override to dump a checkpoint. 
+ """ + pass + + def init_dataloader(self): + self.logger.info("Build dataloader") + self.dataset_train, self.dataset_valid = datasets.get_dataset("cifar10", self.data_dir) + n_train = len(self.dataset_train) + split = n_train // 10 + indices = list(range(n_train)) + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) + self.train_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=self.batch_size, + sampler=train_sampler, + num_workers=self.workers) + self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=self.eval_batch_size, + sampler=valid_sampler, + num_workers=self.workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=self.batch_size, + num_workers=self.workers) + # self.train_loader = cycle(self.train_loader) + # self.valid_loader = cycle(self.valid_loader) + + def train_one_epoch(self,epoch): + """ + Train one epoch. + + Parameters + ---------- + epoch : int + Epoch number starting from 0. + """ + tot_acc = 0 + tot = 0 + losses = [] + step = 0 + self.child_model.train() + meters = AverageMeterGroup() + + for batch in self.train_loader: + step += 1 + + x, y = batch + x, y = to_device(x, self.device), to_device(y, self.device) + + logits = self.child_model(x) + + if isinstance(logits, tuple): + logits, aux_logits = logits + aux_loss = self.criterion(aux_logits, y) + else: + aux_loss = 0. + + acc = accuracy(logits, y) + loss = self.criterion(logits, y) + loss = loss + self.aux_weight * aux_loss + + self.optimizer.zero_grad() + loss.backward() + grad_norm = 0 + trainable_params = self.child_model.parameters() + + # assert FLAGS.child_grad_bound is not None, "Need grad_bound to clip gradients." 
+ # # compute the gradient norm value + # grad_norm = nn.utils.clip_grad_norm_(trainable_params, 99999999) + # for param in trainable_params: + # nn.utils.clip_grad_norm_(param, self.child_grad_bound) # clip grad + # print(param_ == param) + if self.child_grad_bound is not None: + grad_norm = nn.utils.clip_grad_norm_(trainable_params, self.child_grad_bound) + trainable_params = grad_norm + + self.optimizer.step() + + tot_acc += acc['acc1'] + tot += 1 + losses.append(loss) + acc["loss"] = loss.item() + meters.update(acc) + + if step % self.log_every == 0: + curr_time = time.time() + log_string = "" + log_string += "epoch={:<6d}".format(epoch) + log_string += "ch_step={:<6d}".format(step) + log_string += " loss={:<8.6f}".format(loss) + log_string += " lr={:<8.4f}".format(self.optimizer.param_groups[0]['lr']) + log_string += " |g|={:<8.4f}".format(grad_norm) + log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0]) + log_string += " mins={:<10.2f}".format(float(curr_time - self.start_time) / 60) + self.logger.info(log_string) + + print("Model Epoch [%d/%d] %.3f mins %s \n " % (epoch + 1, + self.num_epochs, float(time.time() - self.start_time) / 60, meters )) + final_acc = float(tot_acc) / tot + + losses = torch.tensor(losses) + loss = losses.mean() + + + def validate_one_epoch(self,epoch): + tot_acc = 0 + tot = 0 + losses = [] + meters = AverageMeterGroup() + + with torch.no_grad(): # save memory + meters = AverageMeterGroup() + for batch in self.valid_loader: + x, y = batch + x, y = to_device(x, self.device), to_device(y, self.device) + logits = self.child_model(x) + + if isinstance(logits, tuple): + logits, aux_logits = logits + aux_loss = self.criterion(aux_logits, y) + else: + aux_loss = 0. 
+ + loss = self.criterion(logits, y) + loss = loss + self.aux_weight * aux_loss + # loss = loss.mean() + preds = logits.argmax(dim=1).long() + acc = torch.eq(preds, y.long()).long().sum().item() + acc_v = accuracy(logits, y) + + losses.append(loss) + tot_acc += acc + tot += len(y) + + acc_v["loss"] = loss.item() + meters.update(acc_v) + + losses = torch.tensor(losses) + loss = losses.mean() + if tot > 0: + final_acc = float(tot_acc) / tot + else: + final_acc = 0 + self.logger.info("Error in calculating final_acc") + + with open(self.result_path, "a") as file: + file.write( + str({"type": "Accuracy", + "result": {"sequence": epoch, "category": "epoch", "value": final_acc}}) + '\n') + + # print("Model eval %.3fmins %s \n " % ( + # float(time.time() - self.start_time) / 60, meters )) + print({"type": "Accuracy", + "result": {"sequence": epoch, "category": "epoch", "value": final_acc}}) + + self.logger.info( + "ch_step= {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format( "test", final_acc, "test", loss)) + + +logging.basicConfig(format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s', + level=logging.INFO, + filename='./retrain.log', + filemode='a') +logger = logging.getLogger("enas-retrain") + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="./data", + help="Directory containing the dataset and embedding file. (default: %(default)s)") + parser.add_argument( + "--model_selected_space_path", + type=str, + default="./model_selected_space.json", + # required=True, + help="Architecture json file. 
(default: %(default)s)") + parser.add_argument("--result_path", type=str, + default='./result.json', help="res directory") + parser.add_argument("--search_space_path", type=str, + default='./search_space.json', help="search_space directory") + parser.add_argument("--log_path", type=str, default='output/log') + parser.add_argument( + "--best_selected_space_path", + type=str, + default="./best_selected_space.json", + # required=True, + help="Best architecture selected json file by experiment. (default: %(default)s)") + parser.add_argument( + "--best_checkpoint_dir", + type=str, + default="best_checkpoint", + help="Path for saved checkpoints. (default: %(default)s)") + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + parser.add_argument("--search_for", + choices=["macro", "micro"], + default="macro") + parser.add_argument( + "--batch_size", + type=int, + default=128, + help="Number of samples each batch for training. (default: %(default)s)") + parser.add_argument( + "--eval_batch_size", + type=int, + default=128, + help="Number of samples each batch for evaluation. (default: %(default)s)") + parser.add_argument( + "--epochs", + type=int, + default=10, + help="The number of training epochs. (default: %(default)s)") + parser.add_argument( + "--lr", + type=float, + default=0.02, + help="The initial learning rate. (default: %(default)s)") + parser.add_argument( + "--is_cuda", + type=distutils.util.strtobool, + default=True, + help="Specify the device type. (default: %(default)s)") + parser.add_argument( + "--load_checkpoint", + type=distutils.util.strtobool, + default=False, + help="Whether to load checkpoint. (default: %(default)s)") + parser.add_argument( + "--log_every", + type=int, + default=50, + help="How many steps to log. (default: %(default)s)") + parser.add_argument( + "--eval_every_epochs", + type=int, + default=1, + help="How many epochs to eval. 
(default: %(default)s)") + parser.add_argument( + "--child_grad_bound", + type=float, + default=5.0, + help="The threshold for gradient clipping. (default: %(default)s)") # + parser.add_argument( + "--child_l2_reg", + type=float, + default=3e-6, + help="Weight decay factor. (default: %(default)s)") + parser.add_argument( + "--child_lr_decay_scheme", + type=str, + default="cosine", + help="Learning rate annealing strategy, only 'cosine' supported. (default: %(default)s)") #todo: remove + global FLAGS + FLAGS = parser.parse_args() + +# decode human readable search space to model +def convert_selected_space_format(child_fixed_arc): + # with open('./macro_selected_space.json') as js: + with open(child_fixed_arc) as js: + selected_space = json.load(js) + + ops = selected_space['op_list'] + selected_space.pop('op_list') + new_selected_space = {} + + for key, value in selected_space.items(): + # for macro + if FLAGS.search_for == 'macro': + new_key = key.split('_')[-1] + # for micro + elif FLAGS.search_for == 'micro': + new_key = key + + if len(value) > 1 or len(value)==0: + new_value = value + elif len(value) > 0 and value[0] in ops: + new_value = ops.index(value[0]) + else: + new_value = value[0] + new_selected_space[new_key] = new_value + + return new_selected_space + +def set_random_seed(seed): + logger.info("set random seed for data reading: {}".format(seed)) + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + if FLAGS.is_cuda: + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def main(): + + parse_args() + + child_fixed_arc = FLAGS.best_selected_space_path # './macro_seletced_space' + search_for = FLAGS.search_for + + # set seed to result todo: trial ID + set_random_seed(FLAGS.trial_id) + + mkdirs(FLAGS.result_path, FLAGS.log_path, FLAGS.best_checkpoint_dir) + # define and load model + logger.info('** ' + FLAGS.search_for + 'search **') + fixed_arc = 
convert_selected_space_format(child_fixed_arc) + # Model, macro search or micro search + if FLAGS.search_for == 'macro': + child_model = GeneralNetwork() + elif FLAGS.search_for == 'micro': + child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) + + apply_fixed_architecture(child_model, fixed_arc) + + # load model + if FLAGS.load_checkpoint: + print('** Load model **') + logger.info('** Load model **') + child_model.load_state_dict(torch.load(FLAGS.best_checkpoint_dir)['child_model_state_dict']) + + retrainer = EnasRetrainer(model=child_model, + data_dir = FLAGS.data_dir, + best_checkpoint_dir=FLAGS.best_checkpoint_dir, + batch_size=FLAGS.batch_size, + eval_batch_size=FLAGS.eval_batch_size, + num_epochs=FLAGS.epochs, + lr=FLAGS.lr, + is_cuda=FLAGS.is_cuda, + log_every=FLAGS.log_every, + child_grad_bound=FLAGS.child_grad_bound, + child_l2_reg=FLAGS.child_l2_reg, + eval_every_epochs=FLAGS.eval_every_epochs, + logger=logger, + result_path=FLAGS.result_path, + ) + + t1 = time.time() + retrainer.train() + print('cost time for retrain: ' , time.time() - t1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dubhe-tadl/enas/search.py b/dubhe-tadl/enas/search.py new file mode 100644 index 0000000..6a49540 --- /dev/null +++ b/dubhe-tadl/enas/search.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import sys +sys.path.append('..'+ '/' + '..') +import logging +import time +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +import datasets +from macro import GeneralNetwork +from micro import MicroNetwork +from trainer import EnasTrainer +from mutator import EnasMutator +from pytorch.callbacks import (ArchitectureCheckpoint, + LRSchedulerCallback) +from utils import accuracy, reward_accuracy +from collections import OrderedDict +from pytorch.mutables import LayerChoice, InputChoice +import json +torch.cuda.set_device(4) + +logger = logging.getLogger('tadl-enas') + +# save search space as search_space.json +def save_nas_search_space(mutator,file_path): + result = OrderedDict() + cur_layer_idx = None + for mutable in mutator.mutables.traverse(): + if not isinstance(mutable,(LayerChoice, InputChoice)): + cur_layer_idx = mutable.key + '_' + continue + # macro + if 'layer' in cur_layer_idx: + if isinstance(mutable, LayerChoice): + if 'op_list' not in result: + result['op_list'] = [str(i) for i in mutable] + result[cur_layer_idx + mutable.key] = 'op_list' + else: + result[cur_layer_idx + mutable.key] = {'skip_connection': False if mutable.n_chosen else True, + 'n_chosen': mutable.n_chosen if mutable.n_chosen else '', + 'choose_from': mutable.choose_from if mutable.choose_from else ''} + # micro + elif 'node' in cur_layer_idx: + if isinstance(mutable,LayerChoice): + if 'op_list' not in result: + result['op_list'] = [str(i) for i in mutable] + result[mutable.key] = 'op_list' + else: + result[mutable.key] = {'skip_connection':False if mutable.n_chosen else True, + 'n_chosen': mutable.n_chosen if mutable.n_chosen else '', + 'choose_from': mutable.choose_from if mutable.choose_from else ''} + + dump_global_result(file_path,result) + +# def dump_global_result(args,global_result): +# with open(args['result_path'], "w") as ss_file: +# json.dump(global_result, ss_file, sort_keys=True, indent=2) + +def dump_global_result(res_path,global_result, 
sort_keys = False): + with open(res_path, "w") as ss_file: + json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2) + + + +if __name__ == "__main__": + parser = ArgumentParser("enas") + parser.add_argument("--search_space_path", type=str, + default='./search_space.json', help="search_space directory") + parser.add_argument("--selected_space_path", type=str, + default='./selected_space.json', help="sapce_path_out directory") + parser.add_argument("--result_path", type=str, + default='./result.json', help="res directory") + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--log-frequency", default=10, type=int) + parser.add_argument("--search_for", choices=["macro", "micro"], default="macro") + parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") + args = parser.parse_args() + + # 设置随机种子 + torch.manual_seed(args.trial_id) + torch.cuda.manual_seed_all(args.trial_id) + np.random.seed(args.trial_id) + random.seed(args.trial_id) + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + if args.search_for == "macro": + model = GeneralNetwork() + num_epochs = args.epochs or 310 + mutator = None + mutator = EnasMutator(model) + elif args.search_for == "micro": + model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) + num_epochs = args.epochs or 150 + mutator = EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) + else: + raise AssertionError + + # 储存整个网络结构 + # args.search_spach_path = None#str(args.search_for) + str(args.search_space_path) + # print( args.search_space_path, args.search_for ) + save_nas_search_space(mutator, args.search_space_path) + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) + lr_scheduler = 
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) + + trainer = EnasTrainer(model, + loss=criterion, + metrics=accuracy, + reward_function=reward_accuracy, + optimizer=optimizer, + callbacks=[LRSchedulerCallback(lr_scheduler)], + batch_size=args.batch_size, + num_epochs=num_epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + log_frequency=args.log_frequency, + mutator=mutator, + child_model_path='./'+args.search_for+'_child_model') + + logger.info(trainer.metrics) + + t1 = time.time() + trainer.train() + trainer.result["cost_time"] = time.time() - t1 + dump_global_result(args.result_path,trainer.result) + + selected_model = trainer.export_child_model(selected_space = True) + dump_global_result(args.selected_space_path,selected_model) \ No newline at end of file diff --git a/dubhe-tadl/enas/selector.py b/dubhe-tadl/enas/selector.py new file mode 100644 index 0000000..6275b2a --- /dev/null +++ b/dubhe-tadl/enas/selector.py @@ -0,0 +1,18 @@ +import sys +sys.path.append('../..') +from pytorch.selector import Selector + +class EnasSelector(Selector): + def __init__(self, *args, single_candidate=True): + super().__init__(single_candidate) + self.args = args + + def fit(self): + """ + only one candatite, function passed + """ + pass + +if __name__ == "__main__": + hpo_selector = EnasSelector() + hpo_selector.fit() \ No newline at end of file diff --git a/dubhe-tadl/enas/trainer.py b/dubhe-tadl/enas/trainer.py new file mode 100644 index 0000000..8477901 --- /dev/null +++ b/dubhe-tadl/enas/trainer.py @@ -0,0 +1,436 @@ +from itertools import cycle +import os +import sys +sys.path.append('..'+ '/' + '..') +import numpy as np +import random +import logging +import time +from argparse import ArgumentParser +from collections import OrderedDict +import json +import torch +import torch.nn as nn +import torch.optim as optim + + +# import custom libraries +import datasets +from pytorch.trainer import Trainer +from 
pytorch.utils import AverageMeterGroup, to_device, mkdirs +from pytorch.mutables import LayerChoice, InputChoice, MutableScope +from macro import GeneralNetwork +from micro import MicroNetwork +# from trainer import EnasTrainer +from mutator import EnasMutator +from pytorch.callbacks import (ArchitectureCheckpoint, + LRSchedulerCallback) +from utils import accuracy, reward_accuracy + +torch.cuda.set_device(0) + +logging.basicConfig(format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s', + level=logging.INFO, + filename='./train.log', + filemode='a') +logger = logging.getLogger('enas_train') + +class EnasTrainer(Trainer): + """ + ENAS trainer. + + Parameters + ---------- + model : nn.Module + PyTorch model to be trained. + loss : callable + Receives logits and ground truth label, return a loss tensor. + metrics : callable + Receives logits and ground truth label, return a dict of metrics. + reward_function : callable + Receives logits and ground truth label, return a tensor, which will be feeded to RL controller as reward. + optimizer : Optimizer + The optimizer used for optimizing the model. + num_epochs : int + Number of epochs planned for training. + dataset_train : Dataset + Dataset for training. Will be split for training weights and architecture weights. + dataset_valid : Dataset + Dataset for testing. + mutator : EnasMutator + Use when customizing your own mutator or a mutator with customized parameters. + batch_size : int + Batch size. + workers : int + Workers for data loading. + device : torch.device + ``torch.device("cpu")`` or ``torch.device("cuda")``. + log_frequency : int + Step count per logging. + callbacks : list of Callback + list of callbacks to trigger at events. + entropy_weight : float + Weight of sample entropy loss. + skip_weight : float + Weight of skip penalty loss. + baseline_decay : float + Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``. 
+ child_steps : int + How many mini-batches for model training per epoch. + mutator_lr : float + Learning rate for RL controller. + mutator_steps_aggregate : int + Number of steps that will be aggregated into one mini-batch for RL controller. + mutator_steps : int + Number of mini-batches for each epoch of RL controller learning. + aux_weight : float + Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss. + test_arc_per_epoch : int + How many architectures are chosen for direct test after each epoch. + """ + def __init__(self, model, loss, metrics, reward_function, + optimizer, num_epochs, dataset_train, dataset_valid, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, + entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500, + mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4, + test_arc_per_epoch=1,child_model_path = './', result_path='./'): + super().__init__(model, mutator if mutator is not None else EnasMutator(model), + loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, + batch_size, workers, device, log_frequency, callbacks) + self.reward_function = reward_function + self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) + self.batch_size = batch_size + self.workers = workers + + self.entropy_weight = entropy_weight + self.skip_weight = skip_weight + self.baseline_decay = baseline_decay + self.baseline = 0. 
+ self.mutator_steps_aggregate = mutator_steps_aggregate + self.mutator_steps = mutator_steps + self.child_steps = child_steps + self.aux_weight = aux_weight + self.test_arc_per_epoch = test_arc_per_epoch + self.child_model_path = child_model_path # saving the child model + self.init_dataloader() + # self.result = {'accuracy':[], + # 'cost_time':0} + self.result_path = result_path + with open(self.result_path, "w") as file: + file.write('') + + def init_dataloader(self): + n_train = len(self.dataset_train) + split = n_train // 10 + indices = list(range(n_train)) + train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) + valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) + self.train_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=self.batch_size, + sampler=train_sampler, + num_workers=self.workers) + self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, + batch_size=self.batch_size, + sampler=valid_sampler, + num_workers=self.workers) + self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, + batch_size=self.batch_size, + num_workers=self.workers) + self.train_loader = cycle(self.train_loader) + self.valid_loader = cycle(self.valid_loader) + + def train_one_epoch(self, epoch): + # Sample model and train + self.model.train() + self.mutator.eval() + meters = AverageMeterGroup() + for step in range(1, self.child_steps + 1): + x, y = next(self.train_loader) + x, y = to_device(x, self.device), to_device(y, self.device) + self.optimizer.zero_grad() + + with torch.no_grad(): + self.mutator.reset() + # self._write_graph_status() + logits = self.model(x) + + if isinstance(logits, tuple): + logits, aux_logits = logits + aux_loss = self.loss(aux_logits, y) + else: + aux_loss = 0. + metrics = self.metrics(logits, y) + loss = self.loss(logits, y) + loss = loss + self.aux_weight * aux_loss + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) 
+ self.optimizer.step() + metrics["loss"] = loss.item() + meters.update(metrics) + + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, + self.num_epochs, step, self.child_steps, meters) + + # Train sampler (mutator) + self.model.eval() + self.mutator.train() + meters = AverageMeterGroup() + for mutator_step in range(1, self.mutator_steps + 1): + self.mutator_optim.zero_grad() + for step in range(1, self.mutator_steps_aggregate + 1): + x, y = next(self.valid_loader) + x, y = to_device(x, self.device), to_device(y, self.device) + + self.mutator.reset() + with torch.no_grad(): + logits = self.model(x) + # self._write_graph_status() + metrics = self.metrics(logits, y) + reward = self.reward_function(logits, y) + if self.entropy_weight: + reward += self.entropy_weight * self.mutator.sample_entropy.item() + self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) + loss = self.mutator.sample_log_prob * (reward - self.baseline) + if self.skip_weight: + loss += self.skip_weight * self.mutator.sample_skip_penalty + metrics["reward"] = reward + metrics["loss"] = loss.item() + metrics["ent"] = self.mutator.sample_entropy.item() + metrics["log_prob"] = self.mutator.sample_log_prob.item() + metrics["baseline"] = self.baseline + metrics["skip"] = self.mutator.sample_skip_penalty + + loss /= self.mutator_steps_aggregate + loss.backward() + meters.update(metrics) + + cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate + if self.log_frequency is not None and cur_step % self.log_frequency == 0: + logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs, + mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate, + meters) + + nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.) 
+ self.mutator_optim.step() + + def validate_one_epoch(self, epoch): + with torch.no_grad(): + accuracy = 0 + for arc_id in range(self.test_arc_per_epoch): + meters = AverageMeterGroup() + count, acc_this_round = 0,0 + for x, y in self.test_loader: + x, y = to_device(x, self.device), to_device(y, self.device) + self.mutator.reset() + child_model = self.export_child_model() + # self._generate_child_model(epoch, + # count, + # arc_id, + # child_model, + # self.child_model_path) + logits = self.model(x) + if isinstance(logits, tuple): + logits, _ = logits + metrics = self.metrics(logits, y) + loss = self.loss(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + count += 1 + acc_this_round += metrics['acc1'] + + logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s", + epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch, + meters.summary()) + acc_this_round /= count + accuracy += acc_this_round + # logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) + print({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) + with open(self.result_path, "a") as file: + file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", + "value": meters.get_last_acc()}}) + '\n') + # self.result['accuracy'].append(accuracy / self.test_arc_per_epoch) + + # export child_model + def export_child_model(self, selected_space=False): + if selected_space: + sampled = self.mutator.sample_final() + else: + sampled = self.mutator._cache + result = OrderedDict() + cur_layer_id = None + for mutable in self.mutator.mutables: + if not isinstance(mutable, (LayerChoice, InputChoice)): + cur_layer_id = mutable.key + # not supported as built-in + continue + choosed_ops_idx = self.mutator._convert_mutable_decision_to_human_readable(mutable, sampled[mutable.key]) + if not isinstance(choosed_ops_idx, list): + choosed_ops_idx = 
[choosed_ops_idx] + if isinstance(mutable, LayerChoice): + if 'op_list' not in result: + result['op_list'] = [str(i) for i in mutable] + choosed_ops = [str(mutable[idx]) for idx in choosed_ops_idx] + else: + + choosed_ops = choosed_ops_idx + if 'node' in cur_layer_id: + result[mutable.key] = choosed_ops + else: + result[cur_layer_id + '_' + mutable.key] = choosed_ops + + return result + + def _generate_child_model(self, + validation_epoch, + model_idx, + validation_step, + child_model, + file_path): + + # create child_models folder + # parent_path = os.path.join(file_path, 'child_model') + parent_path = file_path + if not os.path.exists(parent_path): + os.mkdir(parent_path) + + # create secondary directory + secondary_path = os.path.join(parent_path, 'validation_epoch_{}'.format(validation_epoch)) + if not os.path.exists(secondary_path): + os.mkdir(secondary_path) + + # create third directory + folder_path = os.path.join(secondary_path, 'validation_step_{}'.format(validation_step)) + if not os.path.exists(folder_path): + os.mkdir(folder_path) + + # save sampled child_model for validation + saved_path = os.path.join(folder_path, "child_model_%02d.json" % model_idx) + + with open(saved_path, "w") as ss_file: + json.dump(child_model, ss_file, indent=2) + +# save search space as search_space.json +def save_nas_search_space(mutator,file_path): + result = OrderedDict() + cur_layer_idx = None + for mutable in mutator.mutables.traverse(): + if not isinstance(mutable,(LayerChoice, InputChoice)): + cur_layer_idx = mutable.key + '_' + continue + # macro + if 'layer' in cur_layer_idx: + if isinstance(mutable, LayerChoice): + if 'op_list' not in result: + result['op_list'] = [str(i) for i in mutable] + result[cur_layer_idx + mutable.key] = 'op_list' + else: + result[cur_layer_idx + mutable.key] = {'skip_connection': False if mutable.n_chosen else True, + 'n_chosen': mutable.n_chosen if mutable.n_chosen else '', + 'choose_from': mutable.choose_from if mutable.choose_from else 
''} + # micro + elif 'node' in cur_layer_idx: + if isinstance(mutable,LayerChoice): + if 'op_list' not in result: + result['op_list'] = [str(i) for i in mutable] + result[mutable.key] = 'op_list' + else: + result[mutable.key] = {'skip_connection':False if mutable.n_chosen else True, + 'n_chosen': mutable.n_chosen if mutable.n_chosen else '', + 'choose_from': mutable.choose_from if mutable.choose_from else ''} + + dump_global_result(file_path,result) + +# def dump_global_result(args,global_result): +# with open(args['result_path'], "w") as ss_file: +# json.dump(global_result, ss_file, sort_keys=True, indent=2) + +def dump_global_result(res_path,global_result, sort_keys = False): + with open(res_path, "w") as ss_file: + json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2) + + +if __name__ == "__main__": + parser = ArgumentParser("enas") + parser.add_argument( + "--data_dir", + type=str, + default="./data", + help="Directory containing the dataset and embedding file. (default: %(default)s)") + parser.add_argument("--model_selected_space_path", type=str, + default='./model_selected_space.json', help="sapce_path_out directory") + parser.add_argument("--result_path", type=str, + default='./model_result.json', help="res directory") + parser.add_argument("--search_space_path", type=str, + default='./search_space.json', help="search_space directory") + parser.add_argument("--best_selected_space_path", type=str, + default='./model_selected_space.json', help="Best sapce_path_out directory of experiment") + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + parser.add_argument('--lr', type=float, default=0.005, metavar='N', + help='learning rate') + parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") + parser.add_argument("--batch_size", default=128, type=int) + parser.add_argument("--log_frequency", default=10, type=int) + 
parser.add_argument("--search_for", choices=["macro", "micro"], default="macro") + args = parser.parse_args() + + mkdirs(args.result_path, args.search_space_path, args.best_selected_space_path) + # 设置随机种子 + torch.manual_seed(args.trial_id) + torch.cuda.manual_seed_all(args.trial_id) + np.random.seed(args.trial_id) + random.seed(args.trial_id) + # use deterministic instead of nondeterministic algorithm + # make sure exact results can be reproduced everytime. + torch.backends.cudnn.deterministic = True + + + dataset_train, dataset_valid = datasets.get_dataset("cifar10", args.data_dir) + if args.search_for == "macro": + model = GeneralNetwork() + num_epochs = args.epochs or 310 + mutator = None + mutator = EnasMutator(model) + elif args.search_for == "micro": + model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) + num_epochs = args.epochs or 150 + mutator = EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) + else: + raise AssertionError + + # 储存整个网络结构 + # args.search_spach_path = None#str(args.search_for) + str(args.search_space_path) + # print( args.search_space_path, args.search_for ) + save_nas_search_space(mutator, args.search_space_path) + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) + + trainer = EnasTrainer(model, + loss=criterion, + metrics=accuracy, + reward_function=reward_accuracy, + optimizer=optimizer, + callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./"+args.search_for+"_checkpoints")], + batch_size=args.batch_size, + num_epochs=num_epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + log_frequency=args.log_frequency, + mutator=mutator, + child_steps=2, + mutator_steps=2, + child_model_path='./'+args.search_for+'_child_model', + result_path=args.result_path) + + 
logger.info(trainer.metrics) + + t1 = time.time() + trainer.train() + # trainer.result["cost_time"] = time.time() - t1 + # dump_global_result(args.result_path,trainer.result) + + selected_model = trainer.export_child_model(selected_space = True) + dump_global_result(args.best_selected_space_path,selected_model) \ No newline at end of file diff --git a/dubhe-tadl/enas/utils.py b/dubhe-tadl/enas/utils.py new file mode 100644 index 0000000..f680db4 --- /dev/null +++ b/dubhe-tadl/enas/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res + + +def reward_accuracy(output, target, topk=(1,)): + batch_size = target.size(0) + _, predicted = torch.max(output.data, 1) + return (predicted == target).sum().item() / batch_size diff --git a/dubhe-tadl/fixed.py b/dubhe-tadl/fixed.py new file mode 100644 index 0000000..cac7456 --- /dev/null +++ b/dubhe-tadl/fixed.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import logging + +from .mutables import InputChoice, LayerChoice, MutableScope +from .mutator import Mutator +from .utils import to_list + + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +class FixedArchitecture(Mutator): + """ + Fixed architecture mutator that always selects a certain graph. + + Parameters + ---------- + model : nn.Module + A mutable network. + fixed_arc : dict + Preloaded architecture object. 
+ strict : bool + Force everything that appears in ``fixed_arc`` to be used at least once. + """ + + def __init__(self, model, fixed_arc, strict=True): + super().__init__(model) + self._fixed_arc = fixed_arc + + mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)]) + fixed_arc_keys = set(self._fixed_arc.keys()) + if fixed_arc_keys - mutable_keys: + raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) + if mutable_keys - fixed_arc_keys: + raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) + self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc) + + def _from_human_readable_architecture(self, human_arc): + # convert from an exported architecture + result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc. + # First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"}, + # which means {"op1": [0, ]} ir {"op1": ["conv", ]} + result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()} + # Second, infer which ones are multi-hot arrays and which ones are in human-readable format. + # This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true]. + # Here, we assume an multihot array has to be a boolean array or a float array and matches the length. 
+ for mutable in self.mutables: + if mutable.key not in result_arc: + continue # skip silently + choice_arr = result_arc[mutable.key] + if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr): + if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \ + (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)): + # multihot, do nothing + continue + if isinstance(mutable, LayerChoice): + choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr] + choice_arr = [i in choice_arr for i in range(len(mutable))] + elif isinstance(mutable, InputChoice): + choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr] + choice_arr = [i in choice_arr for i in range(mutable.n_candidates)] + result_arc[mutable.key] = choice_arr + return result_arc + + def sample_search(self): + """ + Always returns the fixed architecture. + """ + return self._fixed_arc + + def sample_final(self): + """ + Always returns the fixed architecture. + """ + return self._fixed_arc + + def replace_layer_choice(self, module=None, prefix=""): + """ + Replace layer choices with selected candidates. It's done with best effort. + In case of weighted choices or multiple choices. if some of the choices on weighted with zero, delete them. + If single choice, replace the module with a normal module. + + Parameters + ---------- + module : nn.Module + Module to be processed. + prefix : str + Module name under global namespace. + """ + if module is None: + module = self.model + for name, mutable in module.named_children(): + global_name = (prefix + "." 
if prefix else "") + name + if isinstance(mutable, LayerChoice): + chosen = self._fixed_arc[mutable.key] + if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask: + # sum is one, max is one, there has to be an only one + # this is compatible with both integer arrays, boolean arrays and float arrays + _logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1)) + setattr(module, name, mutable[chosen.index(1)]) + else: + if mutable.return_mask: + _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \ + "LayerChoice will not be replaced.") + # remove unused parameters + for ch, n in zip(chosen, mutable.names): + if ch == 0 and not isinstance(ch, float): + setattr(mutable, n, None) + else: + self.replace_layer_choice(mutable, global_name) + + +def apply_fixed_architecture(model, fixed_arc): + """ + Load architecture from `fixed_arc` and apply to model. + + Parameters + ---------- + model : torch.nn.Module + Model with mutables. + fixed_arc : str or dict + Path to the JSON that stores the architecture, or dict that stores the exported architecture. + + Returns + ------- + FixedArchitecture + Mutator that is responsible for fixes the graph. 
+ """ + + if isinstance(fixed_arc, str): + with open(fixed_arc) as f: + fixed_arc = json.load(f) + architecture = FixedArchitecture(model, fixed_arc) + architecture.reset() + + # for the convenience of parameters counting + architecture.replace_layer_choice() + return architecture diff --git a/dubhe-tadl/log.py b/dubhe-tadl/log.py new file mode 100644 index 0000000..c5f6110 --- /dev/null +++ b/dubhe-tadl/log.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +from datetime import datetime +from io import TextIOBase +import logging +from logging import FileHandler, Formatter, Handler, StreamHandler +from pathlib import Path +import sys +import time +from typing import Optional + +time_format = '%Y-%m-%d %H:%M:%S' + +formatter = Formatter( + '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s', + time_format +) + +def init_logger() -> None: + _setup_root_logger(StreamHandler(sys.stdout), logging.INFO) + + logging.basicConfig() + + +def _prepare_log_dir(path: Optional[str]) -> Path: + if path is None: + return Path() + ret = Path(path) + ret.mkdir(parents=True, exist_ok=True) + return ret + +def _setup_root_logger(handler: Handler, level: int) -> None: + _setup_logger('tadl', handler, level) + +def _setup_logger(name: str, handler: Handler, level: int) -> None: + handler.setFormatter(formatter) + logger = logging.getLogger(name) + logger.addHandler(handler) + logger.setLevel(level) + logger.propagate = False + +class _LogFileWrapper(TextIOBase): + # wrap the logger file so that anything written to it will automatically get formatted + + def __init__(self, log_file: TextIOBase): + self.file: TextIOBase = log_file + self.line_buffer: Optional[str] = None + self.line_start_time: Optional[datetime] = None + + def write(self, s: str) -> int: + cur_time = datetime.now() + if self.line_buffer and (cur_time - self.line_start_time).total_seconds() > 0.1: + self.flush() + + if self.line_buffer: + self.line_buffer += s + else: + self.line_buffer = s + 
class Mutable(nn.Module):
    """
    Mutable is designed to function as a normal layer, with all necessary operators' weights.
    States and weights of architectures should be included in mutator, instead of the layer itself.

    Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
    decisions among different mutables. In mutator's implementation, mutators should use the key to
    distinguish different mutables. Mutables that share the same key should be "similar" to each other.

    Currently the default scope for keys is global. By default, the keys uses a global counter from 1 to
    produce unique ids.

    Parameters
    ----------
    key : str
        The key of mutable.

    Notes
    -----
    The counter is program level, but mutables are model level. In case multiple models are defined, and
    you want to have `counter` starting from 1 in the second model, it's recommended to assign keys manually
    instead of using automatic keys.
    """

    def __init__(self, key=None):
        super().__init__()
        if key is not None:
            if not isinstance(key, str):
                key = str(key)
                logger.warning("Warning: key \"%s\" is not string, converted to string.", key)
            self._key = key
        else:
            # No key given: derive one from the class name and the global counter.
            self._key = self.__class__.__name__ + str(global_mutable_counting())
        self.init_hook = self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def __call__(self, *args, **kwargs):
        # Refuse to run before a mutator has been attached.
        self._check_built()
        return super().__call__(*args, **kwargs)

    def set_mutator(self, mutator):
        """
        Attach ``mutator`` to this mutable. May be called only once.

        Raises
        ------
        RuntimeError
            If a mutator has already been set.
        """
        if "mutator" in self.__dict__:
            raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
                               "Or did you apply multiple fixed architectures?")
        # Written straight into ``__dict__`` so the mutator is not registered as a submodule.
        self.__dict__["mutator"] = mutator

    @property
    def key(self):
        """
        Read-only property of key.
        """
        return self._key

    @property
    def name(self):
        """
        After the search space is parsed, it will be the module name of the mutable.
        Until a name has been assigned, the key is returned instead.
        """
        # Fall back to the actual key, not the literal string "_key".
        return self._name if hasattr(self, "_name") else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        """Raise ``ValueError`` if no mutator has been attached via ``set_mutator``."""
        if not hasattr(self, "mutator"):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
One is to traverse the search space as a tree during initialization + and reset. The other is to implement `enter_mutable_scope` and `exit_mutable_scope`. They are called before and after + the forward method of the class inheriting mutable scope. + + Mutable scopes are also mutables that are listed in the mutator.mutables (search space), but they are not supposed + to appear in the dict of choices. + + Parameters + ---------- + key : str + Key of mutable scope. + """ + def __init__(self, key): + super().__init__(key=key) + + def __call__(self, *args, **kwargs): + try: + self._check_built() + self.mutator.enter_mutable_scope(self) + return super().__call__(*args, **kwargs) + finally: + self.mutator.exit_mutable_scope(self) + + +class LayerChoice(Mutable): + """ + Layer choice selects one of the ``op_candidates``, then apply it on inputs and return results. + In rare cases, it can also select zero or many. + + Layer choice does not allow itself to be nested. + + Parameters + ---------- + op_candidates : list of nn.Module or OrderedDict + A module list to be selected from. + reduction : str + ``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected. + If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum. + ``concat`` concatenate the list at dimension 1. + return_mask : bool + If ``return_mask``, return output tensor and a mask. Otherwise return tensor only. + key : str + Key of the input choice. + + Attributes + ---------- + length : int + Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended. + names : list of str + Names of candidates. + choices : list of Module + Deprecated. A list of all candidate modules in the layer choice module. + ``list(layer_choice)`` is recommended, which will serve the same purpose. + + Notes + ----- + ``op_candidates`` can be a list of modules or a ordered dict of named modules, for example, + + .. 
code-block:: python + + self.op_choice = LayerChoice(OrderedDict([ + ("conv3x3", nn.Conv2d(3, 16, 128)), + ("conv5x5", nn.Conv2d(5, 16, 128)), + ("conv7x7", nn.Conv2d(7, 16, 128)) + ])) + + Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or + ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet. + """ + + def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None): + super().__init__(key=key) + self.names = [] + if isinstance(op_candidates, OrderedDict): + for name, module in op_candidates.items(): + assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \ + "Please don't use a reserved name '{}' for your module.".format(name) + self.add_module(name, module) + self.names.append(name) + elif isinstance(op_candidates, list): + for i, module in enumerate(op_candidates): + self.add_module(str(i), module) + self.names.append(str(i)) + else: + raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates))) + self.reduction = reduction + self.return_mask = return_mask + + def __getitem__(self, idx): + if isinstance(idx, str): + return self._modules[idx] + return list(self)[idx] + + def __setitem__(self, idx, module): + key = idx if isinstance(idx, str) else self.names[idx] + return setattr(self, key, module) + + def __delitem__(self, idx): + if isinstance(idx, slice): + for key in self.names[idx]: + delattr(self, key) + else: + if isinstance(idx, str): + key, idx = idx, self.names.index(idx) + else: + key = self.names[idx] + delattr(self, key) + del self.names[idx] + + @property + def length(self): + warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning) + return len(self) + + def __len__(self): + return len(self.names) + + def __iter__(self): + return map(lambda name: self._modules[name], self.names) + + @property + def choices(self): + warnings.warn("layer_choice.choices is deprecated. 
Use `list(layer_choice)` instead.", DeprecationWarning) + return list(self) + + def forward(self, *args, **kwargs): + """ + Returns + ------- + tuple of tensors + Output and selection mask. If ``return_mask`` is ``False``, only output is returned. + """ + out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs) + if self.return_mask: + return out, mask + return out + + +class InputChoice(Mutable): + """ + Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys). For beginners, + use ``n_candidates`` instead of ``choose_from`` is a safe option. To get the most power out of it, you might want to + know about ``choose_from``. + + The keys in ``choose_from`` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones. + The keys are designed to be the keys of the sources. To help mutators make better decisions, + mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the + output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g., + ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a + module/submodule, it needs to be annotated with a key: that's where a :class:`MutableScope` is needed. + + In the example below, ``input_choice`` is a 4-choose-any. The first 3 is semantically output of cell1, output of cell2, + output of cell3 with respectively. Notice that an extra max pooling is followed by cell1, indicating x1 is not + "actually" the direct output of cell1. + + .. 
code-block:: python + + class Cell(MutableScope): + pass + + class Net(nn.Module): + def __init__(self): + self.cell1 = Cell("cell1") + self.cell2 = Cell("cell2") + self.op = LayerChoice([conv3x3(), conv5x5()], key="op") + self.input_choice = InputChoice(choose_from=["cell1", "cell2", "op", InputChoice.NO_KEY]) + + def forward(self, x): + x1 = max_pooling(self.cell1(x)) + x2 = self.cell2(x) + x3 = self.op(x) + x4 = torch.zeros_like(x) + return self.input_choice([x1, x2, x3, x4]) + + Parameters + ---------- + n_candidates : int + Number of inputs to choose from. + choose_from : list of str + List of source keys to choose from. At least of one of ``choose_from`` and ``n_candidates`` must be fulfilled. + If ``n_candidates`` has a value but ``choose_from`` is None, it will be automatically treated as ``n_candidates`` + number of empty string. + n_chosen : int + Recommended inputs to choose. If None, mutator is instructed to select any. + reduction : str + ``mean``, ``concat``, ``sum`` or ``none``. See :class:`LayerChoice`. + return_mask : bool + If ``return_mask``, return output tensor and a mask. Otherwise return tensor only. + key : str + Key of the input choice. + """ + + NO_KEY = "" + + def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, + reduction="sum", return_mask=False, key=None): + super().__init__(key=key) + # precondition check + assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \ + "must be not None." + if choose_from is not None and n_candidates is None: + n_candidates = len(choose_from) + elif choose_from is None and n_candidates is not None: + choose_from = [self.NO_KEY] * n_candidates + assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`." + assert n_candidates > 0, "Number of candidates must be greater than 0." 
+ assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \ + "than number of candidates." + + self.n_candidates = n_candidates + self.choose_from = choose_from.copy() + self.n_chosen = n_chosen + self.reduction = reduction + self.return_mask = return_mask + + def forward(self, optional_inputs): + """ + Forward method of LayerChoice. + + Parameters + ---------- + optional_inputs : list or dict + Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of + ``choose_from`` in initialization. As a list, inputs must follow the semantic order that is the same as + ``choose_from``. + + Returns + ------- + tuple of tensors + Output and selection mask. If ``return_mask`` is ``False``, only output is returned. + """ + optional_input_list = optional_inputs + if isinstance(optional_inputs, dict): + optional_input_list = [optional_inputs[tag] for tag in self.choose_from] + assert isinstance(optional_input_list, list), \ + "Optional input list must be a list, not a {}.".format(type(optional_input_list)) + assert len(optional_inputs) == self.n_candidates, \ + "Length of the input list must be equal to number of candidates." + out, mask = self.mutator.on_forward_input_choice(self, optional_input_list) + if self.return_mask: + return out, mask + return out + diff --git a/dubhe-tadl/mutator.py b/dubhe-tadl/mutator.py new file mode 100644 index 0000000..e1ea507 --- /dev/null +++ b/dubhe-tadl/mutator.py @@ -0,0 +1,309 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import logging +from collections import defaultdict + +import numpy as np +import torch + +from .base_mutator import BaseMutator +from .mutables import LayerChoice, InputChoice +from .utils import to_list + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Mutator(BaseMutator): + + def __init__(self, model): + super().__init__(model) + self._cache = dict() + self._connect_all = False + + def sample_search(self): + """ + Override to implement this method to iterate over mutables and make decisions. + + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError + + def sample_final(self): + """ + Override to implement this method to iterate over mutables and make decisions that is final + for export and retraining. + + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + raise NotImplementedError + + def reset(self): + """ + Reset the mutator by call the `sample_search` to resample (for search). Stores the result in a local + variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the decision directly. + """ + self._cache = self.sample_search() + + def export(self): + """ + Resample (for final) and return results. + + Returns + ------- + dict + A mapping from key of mutables to decisions. + """ + sampled = self.sample_final() + result = dict() + for mutable in self.mutables: + if not isinstance(mutable, (LayerChoice, InputChoice)): + # not supported as built-in + continue + result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key)) + if sampled: + raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys())) + return result + + def status(self): + """ + Return current selection status of mutator. + + Returns + ------- + dict + A mapping from key of mutables to decisions. All weights (boolean type and float type) + are converted into real number values. 
    def graph(self, inputs):
        """
        Return model supernet graph.

        The model is traced with ``self._connect_all`` switched on; the flag is always
        switched back off afterwards, even if tracing raises.

        Parameters
        ----------
        inputs: tuple of tensor
            Inputs that will be fed into the network.

        Returns
        -------
        dict
            Containing ``node``, in Tensorboard GraphDef format.
            Additional key ``mutable`` is a map from key to list of modules.
        """
        # Only warns; tracing is still attempted on other PyTorch versions.
        if not torch.__version__.startswith("1.4"):
            logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
        # Imported lazily so these packages are only needed when a graph is requested.
        from nni._graph_utils import build_graph
        from google.protobuf import json_format
        # protobuf should be installed as long as tensorboard is installed
        try:
            self._connect_all = True
            graph_def, _ = build_graph(self.model, inputs, verbose=False)
            result = json_format.MessageToDict(graph_def)
        finally:
            self._connect_all = False

        # `mutable` is to map the keys to a list of corresponding modules.
        # A key can be linked to multiple modules, use `deduplicate=False` to find them all.
        result["mutable"] = defaultdict(list)
        for mutable in self.mutables.traverse(deduplicate=False):
            # A module will be represented in the format of
            # [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
            # which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend.
            # This format is aligned with the scope name jit gives.
            modules = mutable.name.split(".")
            path = [
                {"type": self.model.__class__.__name__, "name": ""}
            ]
            # Walk from the root model down the dotted module name, recording each hop.
            m = self.model
            for module in modules:
                m = getattr(m, module)
                path.append({
                    "type": m.__class__.__name__,
                    "name": module
                })
            result["mutable"][mutable.key].append(path)
        return result
+ """ + if self._connect_all: + return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \ + torch.ones(mutable.n_candidates).bool() + mask = self._get_decision(mutable) + assert len(mask) == mutable.n_candidates, \ + "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates) + out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) + return self._tensor_reduction(mutable.reduction, out), mask + + def _select_with_mask(self, map_fn, candidates, mask): + """ + Select masked tensors and return a list of tensors. + + Parameters + ---------- + map_fn : function + Convert candidates to target candidates. Can be simply identity. + candidates : list of torch.Tensor + Tensor list to apply the decision on. + mask : list-like object + Can be a list, an numpy array or a tensor (recommended). Needs to + have the same length as ``candidates``. + + Returns + ------- + tuple of list of torch.Tensor and torch.Tensor + Output and mask. + """ + if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \ + (isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \ + "BoolTensor" in mask.type(): + out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] + elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \ + (isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \ + "FloatTensor" in mask.type(): + out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m] + else: + raise ValueError("Unrecognized mask '%s'" % mask) + if not torch.is_tensor(mask): + mask = torch.tensor(mask) # pylint: disable=not-callable + return out, mask + + def _tensor_reduction(self, reduction_type, tensor_list): + if reduction_type == "none": + return tensor_list + if not tensor_list: + return None # empty. 
return None for now + if len(tensor_list) == 1: + return tensor_list[0] + if reduction_type == "sum": + return sum(tensor_list) + if reduction_type == "mean": + return sum(tensor_list) / len(tensor_list) + if reduction_type == "concat": + return torch.cat(tensor_list, dim=1) + raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type)) + + def _all_connect_tensor_reduction(self, reduction_type, tensor_list): + if reduction_type == "none": + return tensor_list + if reduction_type == "concat": + return torch.cat(tensor_list, dim=1) + return torch.stack(tensor_list).sum(0) + + def _get_decision(self, mutable): + """ + By default, this method checks whether `mutable.key` is already in the decision cache, + and returns the result without double-check. + + Parameters + ---------- + mutable : Mutable + + Returns + ------- + object + """ + if mutable.key not in self._cache: + raise ValueError("\"{}\" not found in decision cache.".format(mutable.key)) + result = self._cache[mutable.key] + logger.debug("Decision %s: %s", mutable.key, result) + return result + + def _convert_mutable_decision_to_human_readable(self, mutable, sampled): + # Assert the existence of mutable.key in returned architecture. + # Also check if there is anything extra. + multihot_list = to_list(sampled) + converted = None + # If it's a boolean array, we can do optimization. + if all([t == 0 or t == 1 for t in multihot_list]): + if isinstance(mutable, LayerChoice): + assert len(multihot_list) == len(mutable), \ + "Results returned from 'sample_final()' (%s: %s) either too short or too long." 
\ + % (mutable.key, multihot_list) + # check if all modules have different names and they indeed have names + if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names): + converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]] + else: + converted = [i for i in range(len(multihot_list)) if multihot_list[i]] + if isinstance(mutable, InputChoice): + assert len(multihot_list) == mutable.n_candidates, \ + "Results returned from 'sample_final()' (%s: %s) either too short or too long." \ + % (mutable.key, multihot_list) + # check if all input candidates have different names + if len(set(mutable.choose_from)) == mutable.n_candidates: + converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]] + else: + converted = [i for i in range(len(multihot_list)) if multihot_list[i]] + if converted is not None: + # if only one element, then remove the bracket + if len(converted) == 1: + converted = converted[0] + else: + # do nothing + converted = multihot_list + return converted diff --git a/dubhe-tadl/network_morphism/README.md b/dubhe-tadl/network_morphism/README.md new file mode 100644 index 0000000..b35c546 --- /dev/null +++ b/dubhe-tadl/network_morphism/README.md @@ -0,0 +1,47 @@ +# Network Morphism +The implementation of the Network Morphism algorithm is based on +[Auto-Keras: An Efficient Neural Architecture Search System](https://arxiv.org/pdf/1806.10282.pdf) + +Train stage +``` +python network_morphism_train.py +--trial_id 0 +--experiment_dir 'tadl' +--log_path 'tadl/train/0/log' +--data_dir '../data/' +--result_path 'trial_id/result.json' +--log_path 'trial_id/log' +--search_space_path 'experiment_id/search_space.json' +--best_selected_space_path 'experiment_id/best_selected_space.json' +--lr 0.001 --epochs 100 --batch_size 32 --opt 'SGD' +``` + +select stage +``` +python network_morphism_select.py +``` + +retrain stage +``` +python network_morphism_retrain.py +--data_dir '../data/' 
def layer_distance(a, b):
    """Return the distance between two layers.

    Layers where ``a`` is not an instance of ``b``'s type are at distance 1.0.
    Conv and Pooling layers are compared attribute by attribute; any other
    layer family is at distance 0.0.
    """
    if not isinstance(a, type(b)):
        return 1.0
    # Layer family -> attributes compared, checked in this order.
    compared_attributes = (
        ("Conv", ("filters", "kernel_size", "stride")),
        ("Pooling", ("padding", "kernel_size", "stride")),
    )
    for family, attr_names in compared_attributes:
        if is_layer(a, family):
            pairs = [(getattr(a, attr), getattr(b, attr)) for attr in attr_names]
            return attribute_difference(pairs)
    return 0.0
def skip_connection_distance(a, b):
    """Return the distance between two skip-connections.

    Each connection is indexed as ``(start, end, kind)``: entries 0 and 1 are the
    ranks of its two ends and entry 2 is compared for equality.

    Connections whose third entries differ are at distance 1.0. Otherwise the
    distance is the difference in start rank plus the difference in length,
    divided by the larger start rank plus the larger length.
    """
    if a[2] != b[2]:
        return 1.0
    len_a = abs(a[1] - a[0])
    len_b = abs(b[1] - b[0])
    denominator = max(a[0], b[0]) + max(len_a, len_b)
    if denominator == 0:
        # Assuming non-negative ranks, this only happens when both connections start
        # at rank 0 with zero length, i.e. they coincide; avoid dividing by zero.
        return 0.0
    return (abs(a[0] - b[0]) + abs(len_a - len_b)) / denominator
# equation (4)
def edit_distance(x, y):
    """The distance between two neural networks.

    Args:
        x: An instance of NetworkDescriptor.
        y: An instance of NetworkDescriptor.
    Returns:
        The edit-distance between x and y.
    """
    ret = layers_distance(x.layers, y.layers)
    ret += Constant.KERNEL_LAMBDA * skip_connections_distance(
        x.skip_connections, y.skip_connections
    )
    return ret


class IncrementalGaussianProcess:
    """Gaussian process regressor.

    Attributes:
        alpha: A hyperparameter added to the kernel diagonal before the
            Cholesky factorisation.
    """

    def __init__(self):
        self.alpha = 1e-10
        self._distance_matrix = None
        self._x = None
        self._y = None
        self._first_fitted = False
        self._l_matrix = None
        self._alpha_vector = None

    @property
    def kernel_matrix(self):
        """The stored pairwise edit-distance matrix of the training set."""
        return self._distance_matrix

    def fit(self, train_x, train_y):
        """Fit the regressor with more data.

        Args:
            train_x: A list of NetworkDescriptor.
            train_y: A list of metric values.
        """
        if self.first_fitted:
            self.incremental_fit(train_x, train_y)
        else:
            self.first_fit(train_x, train_y)

    # Computes the kernel matrix K and the alpha vector. The only difference
    # from first_fit is that the new samples extend the distance matrix.
    def incremental_fit(self, train_x, train_y):
        """Incrementally fit the regressor.

        Args:
            train_x: A list of the newly observed NetworkDescriptor.
            train_y: A list of the matching metric values.
        Returns:
            self. When the Cholesky factorisation fails the error is logged
            and the previous fit is left in place.
        Raises:
            ValueError: If first_fit has not been called yet.
        """
        if not self._first_fitted:
            raise ValueError(
                "The first_fit function needs to be called first.")

        train_x, train_y = np.array(train_x), np.array(train_y)

        # Incrementally compute K: extend the old distance matrix with the
        # old/new and new/new blocks.
        up_right_k = edit_distance_matrix(self._x, train_x)
        down_left_k = np.transpose(up_right_k)
        down_right_k = edit_distance_matrix(train_x)
        up_k = np.concatenate((self._distance_matrix, up_right_k), axis=1)
        down_k = np.concatenate((down_left_k, down_right_k), axis=1)
        temp_distance_matrix = np.concatenate((up_k, down_k), axis=0)

        k_matrix = bourgain_embedding_matrix(temp_distance_matrix)

        # alpha is added only to the diagonal entries of the new samples.
        diagonal = np.diag_indices_from(k_matrix)
        diagonal = (diagonal[0][-len(train_x):], diagonal[1][-len(train_x):])
        k_matrix[diagonal] += self.alpha
        try:
            self._l_matrix = cholesky(k_matrix, lower=True)  # Line 2
        except LinAlgError:
            # Keep the previous state; the traceback goes to the log.
            logger.exception('LinAlgError')
            return self

        self._x = np.concatenate((self._x, train_x), axis=0)
        self._y = np.concatenate((self._y, train_y), axis=0)
        self._distance_matrix = temp_distance_matrix
        self._alpha_vector = cho_solve(
            (self._l_matrix, True), self._y)  # Line 3

        return self

    @property
    def first_fitted(self):
        """Whether first_fit has already been performed."""
        return self._first_fitted

    # The update step for the very first fit.
    def first_fit(self, train_x, train_y):
        """Fit the regressor for the first time.

        Args:
            train_x: A list of NetworkDescriptor.
            train_y: A list of metric values.
        Returns:
            self.
        """
        train_x, train_y = np.array(train_x), np.array(train_y)

        self._x = np.copy(train_x)
        self._y = np.copy(train_y)

        self._distance_matrix = edit_distance_matrix(self._x)
        k_matrix = bourgain_embedding_matrix(self._distance_matrix)
        k_matrix[np.diag_indices_from(k_matrix)] += self.alpha

        self._l_matrix = cholesky(k_matrix, lower=True)  # Line 2

        # cho_solve solves Ax = b and returns x = A^{-1} b
        self._alpha_vector = cho_solve(
            (self._l_matrix, True), self._y)  # Line 3

        self._first_fitted = True
        return self

    # Mean and standard deviation of the predictive distribution.
    def predict(self, train_x):
        """Predict the result.

        Args:
            train_x: A list of NetworkDescriptor.
        Returns:
            y_mean: The predicted mean.
            y_std: The predicted standard deviation.
        """
        k_trans = np.exp(-np.power(edit_distance_matrix(train_x, self._x), 2))
        y_mean = k_trans.dot(self._alpha_vector)  # Line 4 (y_mean = f_star)

        # compute inverse K_inv of K based on its Cholesky
        # decomposition L and its inverse L_inv
        l_inv = solve_triangular(
            self._l_matrix.T, np.eye(self._l_matrix.shape[0]))
        k_inv = l_inv.dot(l_inv.T)
        # Compute variance of predictive distribution.
        # The builtin float replaces the np.float alias, which NumPy removed.
        y_var = np.ones(len(train_x), dtype=float)
        y_var -= np.einsum("ij,ij->i", np.dot(k_trans, k_inv), k_trans)

        # Check if any of the variances is negative because of
        # numerical issues. If yes: set the variance to 0.
        y_var_negative = y_var < 0
        if np.any(y_var_negative):
            y_var[y_var_negative] = 0.0
        return y_mean, np.sqrt(y_var)


def edit_distance_matrix(train_x, train_y=None):
    """Calculate the edit distance.

    Args:
        train_x: An array of neural architectures.
        train_y: An array of neural architectures. When omitted, the
            symmetric matrix of train_x against itself is returned.
    Returns:
        An edit-distance matrix.
    """
    if train_y is None:
        size = train_x.shape[0]
        ret = np.zeros((size, size))
        for x_index, x in enumerate(train_x):
            for y_index, y in enumerate(train_x):
                # Only the upper triangle is computed; the matrix is
                # symmetric and its diagonal stays 0.
                if x_index < y_index:
                    distance = edit_distance(x, y)
                    ret[x_index][y_index] = distance
                    ret[y_index][x_index] = distance
        return ret
    ret = np.zeros((train_x.shape[0], train_y.shape[0]))
    for x_index, x in enumerate(train_x):
        for y_index, y in enumerate(train_y):
            ret[x_index][y_index] = edit_distance(x, y)
    return ret


def vector_distance(a, b):
    """The Euclidean distance between two vectors."""
    a = np.array(a)
    b = np.array(b)
    return np.linalg.norm(a - b)
# Maps the edit-distance matrix space into Euclidean space.
def bourgain_embedding_matrix(distance_matrix):
    """Use Bourgain algorithm to embed the neural architectures based on their edit-distance.

    Args:
        distance_matrix: A matrix of edit-distances.
    Returns:
        A matrix of distances after embedding. For a single architecture
        the input matrix is returned unchanged.
    """
    distance_matrix = np.array(distance_matrix)
    n = len(distance_matrix)
    if n == 1:
        return distance_matrix
    # NOTE: this reseeds numpy's global generator on every call.
    np.random.seed(123)
    r = range(n)
    k = int(math.ceil(math.log(n) / math.log(2) - 1))
    rounds = int(math.ceil(math.log(n)))
    coordinates = []
    for i in range(0, k + 1):
        # The round count gets its own name: reusing the loop variable for
        # it made every scale run one round fewer than the previous one.
        for _ in range(rounds):
            subset = np.random.choice(r, 2 ** i)
            # One embedding coordinate per sampled subset: the distance of
            # each architecture to the nearest member of the subset.
            coordinates.append(distance_matrix[:, subset].min(axis=1))
    distort_elements = np.column_stack(coordinates)
    return rbf_kernel(distort_elements, distort_elements)


class BayesianOptimizer:
    """A Bayesian optimizer for neural architectures.

    Attributes:
        searcher: The Searcher which is calling the Bayesian optimizer.
        t_min: The minimum temperature for simulated annealing.
        optimizemode: The OptimizeMode deciding whether metrics are
            maximised or minimised.
        gpr: A GaussianProcessRegressor for bayesian optimization.
        beta: The beta in acquisition function. (refer to our paper)
        search_tree: The network morphism search tree.
    """

    def __init__(self, searcher, t_min, optimizemode, beta=None):
        self.searcher = searcher
        self.t_min = t_min
        self.optimizemode = optimizemode
        self.gpr = IncrementalGaussianProcess()
        self.beta = beta if beta is not None else Constant.BETA
        self.search_tree = SearchTree()

    def fit(self, x_queue, y_queue):
        """Fit the optimizer with new architectures and performances.

        Args:
            x_queue: A list of NetworkDescriptor.
            y_queue: A list of metric values.
        """
        self.gpr.fit(x_queue, y_queue)

    # Algorithm 1
    # optimize acquisition function
    def generate(self, descriptors):
        """Generate new architecture.

        Args:
            descriptors: All the searched neural architectures. (search history)
        Returns:
            graph: An instance of Graph. A morphed neural network with weights.
            father_id: The father node ID in the search tree.
            Both are None when no unseen architecture was found.
        """
        model_ids = self.search_tree.adj_list.keys()

        target_graph = None
        father_id = None
        descriptors = deepcopy(descriptors)
        maximize = self.optimizemode is OptimizeMode.Maximize
        elem_class = ReverseElem if maximize else Elem

        # Initialise the priority queue with every model generated so far.
        pq = PriorityQueue()
        temp_list = []
        for model_id in model_ids:
            metric_value = self.searcher.get_metric_value_by_id(model_id)
            temp_list.append((metric_value, model_id))
        temp_list = sorted(temp_list)
        for metric_value, model_id in temp_list:
            graph = self.searcher.load_model_by_id(model_id)
            graph.clear_operation_history()
            graph.clear_weights()
            # An already generated model is its own father.
            pq.put(elem_class(metric_value, model_id, graph))

        t = 1.0
        t_min = self.t_min
        alpha = 0.9
        opt_acq = self._get_init_opt_acq_value()
        while not pq.empty() and t > t_min:
            elem = pq.get()
            if maximize:
                temp_exp = min((elem.metric_value - opt_acq) / t, 1.0)
            else:
                temp_exp = min((opt_acq - elem.metric_value) / t, 1.0)
            ap = math.exp(temp_exp)
            if ap >= random.uniform(0, 1):
                # line 9,10 in algorithm 1
                for temp_graph in transform(elem.graph):
                    # Skip networks that have already been seen.
                    if contain(descriptors, temp_graph.extract_descriptor()):
                        continue

                    # The acquisition value is the Bayesian model's score.
                    temp_acq_value = self.acq(temp_graph)

                    # The queue keeps growing: morphed networks are pushed
                    # too, remembering which father they grew from.
                    pq.put(
                        elem_class(
                            temp_acq_value,
                            elem.father_id,
                            temp_graph))
                    descriptors.append(temp_graph.extract_descriptor())
                    # Keep the best candidate as the target.
                    if self._accept_new_acq_value(opt_acq, temp_acq_value):
                        opt_acq = temp_acq_value
                        father_id = elem.father_id
                        target_graph = deepcopy(temp_graph)
            t *= alpha

        # Did not find a non-duplicated architecture
        if father_id is None:
            return None, None
        nm_graph = self.searcher.load_model_by_id(father_id)
        # Replay the recorded operations on the father graph. The history
        # was cleared before queueing, so target_graph only holds the steps
        # from the father to the target, while nm_graph keeps its full one.
        for args in target_graph.operation_history:
            getattr(nm_graph, args[0])(*list(args[1:]))
        return nm_graph, father_id

    # equation (10)
    def acq(self, graph):
        """Estimate the acquisition value of a generated graph."""
        mean, std = self.gpr.predict(np.array([graph.extract_descriptor()]))
        if self.optimizemode is OptimizeMode.Maximize:
            return mean + self.beta * std
        return mean - self.beta * std

    def _get_init_opt_acq_value(self):
        """Return the worst possible acquisition value for the current mode."""
        if self.optimizemode is OptimizeMode.Maximize:
            return -np.inf
        return np.inf

    def _accept_new_acq_value(self, opt_acq, temp_acq_value):
        """Tell whether temp_acq_value strictly improves on opt_acq."""
        if self.optimizemode is OptimizeMode.Maximize:
            return bool(temp_acq_value > opt_acq)
        return bool(temp_acq_value < opt_acq)

    def add_child(self, father_id, model_id):
        """Add child to the search tree.

        Arguments:
            father_id {int} -- father id
            model_id {int} -- model id
        """
        self.search_tree.add_child(father_id, model_id)


@total_ordering
class Elem:
    """Elements to be sorted according to metric value."""

    def __init__(self, metric_value, father_id, graph):
        self.father_id = father_id
        self.graph = graph
        self.metric_value = metric_value

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute.
        if not isinstance(other, Elem):
            return NotImplemented
        return self.metric_value == other.metric_value

    def __lt__(self, other):
        if not isinstance(other, Elem):
            return NotImplemented
        return self.metric_value < other.metric_value


class ReverseElem(Elem):
    """Elements to be reversely sorted according to metric value."""

    def __lt__(self, other):
        if not isinstance(other, Elem):
            return NotImplemented
        return self.metric_value > other.metric_value
def contain(descriptors, target_descriptor):
    """Check if the target descriptor is in the descriptors.

    Two descriptors count as the same when their edit distance is
    below 1e-5.
    """
    return any(
        edit_distance(descriptor, target_descriptor) < 1e-5
        for descriptor in descriptors
    )


class SearchTree:
    """The network morphism search tree.

    Attributes:
        root: The id of the root model, or None while the tree is empty.
        adj_list: A dict mapping every model id to the list of its children.
    """

    def __init__(self):
        self.root = None
        self.adj_list = {}

    def add_child(self, u, v):
        """Add child to search tree itself.

        Arguments:
            u {int} -- father id; -1 makes v the root of the tree
            v {int} -- child id
        Raises:
            KeyError: If the father id is not in the tree yet.
        """
        if u == -1:
            self.root = v
            self.adj_list[v] = []
            return
        if v not in self.adj_list[u]:
            self.adj_list[u].append(v)
        if v not in self.adj_list:
            self.adj_list[v] = []

    def get_dict(self, u=None):
        """A recursive function to return the content of the tree in a dict.

        Arguments:
            u {int} -- id of the subtree root; the whole tree when omitted
        Returns:
            A dict with keys "name" and "children". An empty tree yields a
            node whose name is None and which has no children.
        """
        if u is None:
            if self.root is None:
                # Without this guard the call would recurse on None forever.
                return {"name": None, "children": []}
            u = self.root
        children = [self.get_dict(v) for v in self.adj_list[u]]
        return {"name": u, "children": children}
neural network kernel. + It only record the width of convolutional and dense layers, and the skip-connection types and positions. + """ + + CONCAT_CONNECT = "concat" + ADD_CONNECT = "add" + + def __init__(self): + self.skip_connections = [] + self.layers = [] + + @property + def n_layers(self): + return len(self.layers) + + def add_skip_connection(self, u, v, connection_type): + """ Add a skip-connection to the descriptor. + Args: + u: Number of convolutional layers before the starting point. + v: Number of convolutional layers before the ending point. + connection_type: Must be either CONCAT_CONNECT or ADD_CONNECT. + """ + if connection_type not in [self.CONCAT_CONNECT, self.ADD_CONNECT]: + raise ValueError( + "connection_type should be NetworkDescriptor.CONCAT_CONNECT " + "or NetworkDescriptor.ADD_CONNECT." + ) + + self.skip_connections.append((u, v, connection_type)) + + def to_json(self): + ''' NetworkDescriptor to json representation + ''' + + skip_list = [] + for u, v, connection_type in self.skip_connections: + skip_list.append({"from": u, "to": v, "type": connection_type}) + return {"node_list": self.layers, "skip_list": skip_list} + + def add_layer(self, layer): + ''' add one layer + ''' + self.layers.append(layer) + + +class Node: + """A class for intermediate output tensor (node) in the Graph. + Attributes: + shape: A tuple describing the shape of the tensor. + """ + + def __init__(self, shape): + self.shape = shape + + +class Graph: + """A class representing the neural architecture graph of a model. + Graph extracts the neural architecture graph from a model. + Each node in the graph is a intermediate tensor between layers. + Each layer is an edge in the graph. + Notably, multiple edges may refer to the same layer. + (e.g. Add layer is adding two tensor into one tensor. So it is related to two edges.) + Attributes: + weighted: A boolean of whether the weights and biases in the neural network + should be included in the graph. 
+ input_shape: A tuple of integers, which does not include the batch axis. + node_list: A list of integers. The indices of the list are the identifiers. + layer_list: A list of stub layers. The indices of the list are the identifiers. + node_to_id: A dict instance mapping from node integers to their identifiers. + layer_to_id: A dict instance mapping from stub layers to their identifiers. + layer_id_to_input_node_ids: A dict instance mapping from layer identifiers + to their input nodes identifiers. + layer_id_to_output_node_ids: A dict instance mapping from layer identifiers + to their output nodes identifiers. + adj_list: A two dimensional list. The adjacency list of the graph. The first dimension is + identified by tensor identifiers. In each edge list, the elements are two-element tuples + of (tensor identifier, layer identifier). + reverse_adj_list: A reverse adjacent list in the same format as adj_list. + operation_history: A list saving all the network morphism operations. + vis: A dictionary of temporary storage for whether an local operation has been done + during the network morphism. + """ + + def __init__(self, input_shape, weighted=True): + """Initializer for Graph. + """ + self.input_shape = input_shape + self.weighted = weighted + self.node_list = [] + self.layer_list = [] + # node id start with 0 + self.node_to_id = {} + self.layer_to_id = {} + self.layer_id_to_input_node_ids = {} + self.layer_id_to_output_node_ids = {} + self.adj_list = {} + self.reverse_adj_list = {} + self.operation_history = [] + self.n_dim = len(input_shape) - 1 + self.conv = get_conv_class(self.n_dim) + self.batch_norm = get_batch_norm_class(self.n_dim) + + self.vis = None + self._add_node(Node(input_shape)) + + def add_layer(self, layer, input_node_id): + """Add a layer to the Graph. + Args: + layer: An instance of the subclasses of StubLayer in layers.py. + input_node_id: An integer. The ID of the input node of the layer. + Returns: + output_node_id: An integer. 
The ID of the output node of the layer. + """ + if isinstance(input_node_id, Iterable): + layer.input = list(map(lambda x: self.node_list[x], input_node_id)) + output_node_id = self._add_node(Node(layer.output_shape)) + for node_id in input_node_id: + self._add_edge(layer, node_id, output_node_id) + + else: + layer.input = self.node_list[input_node_id] + output_node_id = self._add_node(Node(layer.output_shape)) + self._add_edge(layer, input_node_id, output_node_id) + + layer.output = self.node_list[output_node_id] + return output_node_id + + def clear_operation_history(self): + self.operation_history = [] + + @property + def n_nodes(self): + """Return the number of nodes in the model.""" + return len(self.node_list) + + @property + def n_layers(self): + """Return the number of layers in the model.""" + return len(self.layer_list) + + def _add_node(self, node): + """Add a new node to node_list and give the node an ID. + Args: + node: An instance of Node. + Returns: + node_id: An integer. + """ + node_id = len(self.node_list) + self.node_to_id[node] = node_id + self.node_list.append(node) + self.adj_list[node_id] = [] + self.reverse_adj_list[node_id] = [] + return node_id + + def _add_edge(self, layer, input_id, output_id): + """Add a new layer to the graph. 
The nodes should be created in advance.""" + + if layer in self.layer_to_id: + layer_id = self.layer_to_id[layer] + if input_id not in self.layer_id_to_input_node_ids[layer_id]: + self.layer_id_to_input_node_ids[layer_id].append(input_id) + if output_id not in self.layer_id_to_output_node_ids[layer_id]: + self.layer_id_to_output_node_ids[layer_id].append(output_id) + else: + layer_id = len(self.layer_list) + self.layer_list.append(layer) + self.layer_to_id[layer] = layer_id + self.layer_id_to_input_node_ids[layer_id] = [input_id] + self.layer_id_to_output_node_ids[layer_id] = [output_id] + + self.adj_list[input_id].append((output_id, layer_id)) + self.reverse_adj_list[output_id].append((input_id, layer_id)) + + def _redirect_edge(self, u_id, v_id, new_v_id): + """Redirect the layer to a new node. + Change the edge originally from `u_id` to `v_id` into an edge from `u_id` to `new_v_id` + while keeping all other property of the edge the same. + """ + layer_id = None + for index, edge_tuple in enumerate(self.adj_list[u_id]): + if edge_tuple[0] == v_id: + layer_id = edge_tuple[1] + self.adj_list[u_id][index] = (new_v_id, layer_id) + self.layer_list[layer_id].output = self.node_list[new_v_id] + break + + for index, edge_tuple in enumerate(self.reverse_adj_list[v_id]): + if edge_tuple[0] == u_id: + layer_id = edge_tuple[1] + self.reverse_adj_list[v_id].remove(edge_tuple) + break + self.reverse_adj_list[new_v_id].append((u_id, layer_id)) + for index, value in enumerate( + self.layer_id_to_output_node_ids[layer_id]): + if value == v_id: + self.layer_id_to_output_node_ids[layer_id][index] = new_v_id + break + + def _replace_layer(self, layer_id, new_layer): + """Replace the layer with a new layer.""" + old_layer = self.layer_list[layer_id] + new_layer.input = old_layer.input + new_layer.output = old_layer.output + new_layer.output.shape = new_layer.output_shape + self.layer_list[layer_id] = new_layer + self.layer_to_id[new_layer] = layer_id + self.layer_to_id.pop(old_layer) 
@property
def topological_order(self):
    """Return the topological order of the node IDs from the input node to the output node."""
    in_degree = {node_id: 0 for node_id in range(self.n_nodes)}
    for u in range(self.n_nodes):
        for v, _ in self.adj_list[u]:
            in_degree[v] += 1

    # order_list doubles as the FIFO queue: nodes are appended once all
    # their predecessors are placed and are processed in that same order.
    order_list = [i for i in range(self.n_nodes) if in_degree[i] == 0]
    head = 0
    while head < len(order_list):
        u = order_list[head]
        head += 1
        for v, _ in self.adj_list[u]:
            in_degree[v] -= 1
            if in_degree[v] == 0:
                order_list.append(v)
    return order_list


def _get_pooling_layers(self, start_node_id, end_node_id):
    """Given two node IDs, return all the pooling layers between them.

    Conv layer with stride > 1 is also considered as a Pooling layer.

    Raises:
        ValueError: If end_node_id cannot be reached from start_node_id.
    """
    layer_list = []
    node_list = [start_node_id]
    # The search fills layer_list as a side effect, so it must not live
    # inside an assert: under ``python -O`` it would never run.
    if not self._depth_first_search(end_node_id, layer_list, node_list):
        raise ValueError(
            "No path from node {} to node {}.".format(start_node_id, end_node_id)
        )
    ret = []
    for layer_id in layer_list:
        layer = self.layer_list[layer_id]
        if is_layer(layer, "Pooling"):
            ret.append(layer)
        elif is_layer(layer, "Conv") and layer.stride != 1:
            ret.append(layer)
    return ret


def _depth_first_search(self, target_id, layer_id_list, node_list):
    """Search for all the layers and nodes down the path.

    A recursive function to search all the layers and nodes between the
    node in the node_list and the node with target_id.

    Args:
        target_id: The node ID to reach.
        layer_id_list: Output list, extended with the layer IDs on the path.
        node_list: The path so far; its last element is the current node.
    Returns:
        True when target_id was reached. Both lists then describe the path;
        otherwise they are restored to their state on entry.
    """
    # Sanity check: a path can never hold more nodes than the graph.
    assert len(node_list) <= self.n_nodes
    u = node_list[-1]
    if u == target_id:
        return True

    for v, layer_id in self.adj_list[u]:
        layer_id_list.append(layer_id)
        node_list.append(v)
        if self._depth_first_search(target_id, layer_id_list, node_list):
            return True
        # Dead end: undo the step before trying the next edge.
        layer_id_list.pop()
        node_list.pop()

    return False
+ Args: + u: The starting node ID. + start_dim: The position to insert the additional dimensions. + total_dim: The total number of dimensions the layer has before widening. + n_add: The number of dimensions to add. + """ + if (u, start_dim, total_dim, n_add) in self.vis: + return + self.vis[(u, start_dim, total_dim, n_add)] = True + for v, layer_id in self.adj_list[u]: + layer = self.layer_list[layer_id] + + if is_layer(layer, "Conv"): + new_layer = wider_next_conv( + layer, start_dim, total_dim, n_add, self.weighted + ) + self._replace_layer(layer_id, new_layer) + + elif is_layer(layer, "Dense"): + new_layer = wider_next_dense( + layer, start_dim, total_dim, n_add, self.weighted + ) + self._replace_layer(layer_id, new_layer) + + elif is_layer(layer, "BatchNormalization"): + new_layer = wider_bn( + layer, start_dim, total_dim, n_add, self.weighted) + self._replace_layer(layer_id, new_layer) + self._search(v, start_dim, total_dim, n_add) + + elif is_layer(layer, "Concatenate"): + if self.layer_id_to_input_node_ids[layer_id][1] == u: + # u is on the right of the concat + # next_start_dim += next_total_dim - total_dim + left_dim = self._upper_layer_width( + self.layer_id_to_input_node_ids[layer_id][0] + ) + next_start_dim = start_dim + left_dim + next_total_dim = total_dim + left_dim + else: + next_start_dim = start_dim + next_total_dim = total_dim + self._upper_layer_width( + self.layer_id_to_input_node_ids[layer_id][1] + ) + self._search(v, next_start_dim, next_total_dim, n_add) + + else: + self._search(v, start_dim, total_dim, n_add) + + for v, layer_id in self.reverse_adj_list[u]: + layer = self.layer_list[layer_id] + if is_layer(layer, "Conv"): + new_layer = wider_pre_conv(layer, n_add, self.weighted) + self._replace_layer(layer_id, new_layer) + elif is_layer(layer, "Dense"): + new_layer = wider_pre_dense(layer, n_add, self.weighted) + self._replace_layer(layer_id, new_layer) + elif is_layer(layer, "Concatenate"): + continue + else: + self._search(v, start_dim, 
total_dim, n_add) + + def _upper_layer_width(self, u): + for v, layer_id in self.reverse_adj_list[u]: + layer = self.layer_list[layer_id] + if is_layer(layer, "Conv") or is_layer(layer, "Dense"): + return layer_width(layer) + elif is_layer(layer, "Concatenate"): + a = self.layer_id_to_input_node_ids[layer_id][0] + b = self.layer_id_to_input_node_ids[layer_id][1] + return self._upper_layer_width(a) + self._upper_layer_width(b) + else: + return self._upper_layer_width(v) + return self.node_list[0].shape[-1] + + def to_deeper_model(self, target_id, new_layer): + """Insert a relu-conv-bn block after the target block. + Args: + target_id: A convolutional layer ID. The new block should be inserted after the block. + new_layer: An instance of StubLayer subclasses. + """ + self.operation_history.append( + ("to_deeper_model", target_id, new_layer)) + input_id = self.layer_id_to_input_node_ids[target_id][0] + output_id = self.layer_id_to_output_node_ids[target_id][0] + if self.weighted: + if is_layer(new_layer, "Dense"): + init_dense_weight(new_layer) + elif is_layer(new_layer, "Conv"): + init_conv_weight(new_layer) + elif is_layer(new_layer, "BatchNormalization"): + init_bn_weight(new_layer) + + self._insert_new_layers([new_layer], input_id, output_id) + + def to_wider_model(self, pre_layer_id, n_add): + """Widen the last dimension of the output of the pre_layer. + Args: + pre_layer_id: The ID of a convolutional layer or dense layer. + n_add: The number of dimensions to add. + """ + self.operation_history.append(("to_wider_model", pre_layer_id, n_add)) + pre_layer = self.layer_list[pre_layer_id] + output_id = self.layer_id_to_output_node_ids[pre_layer_id][0] + dim = layer_width(pre_layer) + self.vis = {} + self._search(output_id, dim, dim, n_add) + # Update the tensor shapes. 
+ for u in self.topological_order: + for v, layer_id in self.adj_list[u]: + self.node_list[v].shape = self.layer_list[layer_id].output_shape + + def _insert_new_layers(self, new_layers, start_node_id, end_node_id): + """Insert the new_layers after the node with start_node_id.""" + new_node_id = self._add_node(deepcopy(self.node_list[end_node_id])) + temp_output_id = new_node_id + for layer in new_layers[:-1]: + temp_output_id = self.add_layer(layer, temp_output_id) + + self._add_edge(new_layers[-1], temp_output_id, end_node_id) + new_layers[-1].input = self.node_list[temp_output_id] + new_layers[-1].output = self.node_list[end_node_id] + self._redirect_edge(start_node_id, end_node_id, new_node_id) + + def _block_end_node(self, layer_id, block_size): + ret = self.layer_id_to_output_node_ids[layer_id][0] + for _ in range(block_size - 2): + ret = self.adj_list[ret][0][0] + return ret + + def _dense_block_end_node(self, layer_id): + return self.layer_id_to_input_node_ids[layer_id][0] + + def _conv_block_end_node(self, layer_id): + """Get the input node ID of the last layer in the block by layer ID. + Return the input node ID of the last layer in the convolutional block. + Args: + layer_id: the convolutional layer ID. + """ + return self._block_end_node(layer_id, Constant.CONV_BLOCK_DISTANCE) + + def to_add_skip_model(self, start_id, end_id): + """Add a weighted add skip-connection from after start node to end node. + Args: + start_id: The convolutional layer ID, after which to start the skip-connection. + end_id: The convolutional layer ID, after which to end the skip-connection. 
+ """ + self.operation_history.append(("to_add_skip_model", start_id, end_id)) + filters_end = self.layer_list[end_id].output.shape[-1] + filters_start = self.layer_list[start_id].output.shape[-1] + start_node_id = self.layer_id_to_output_node_ids[start_id][0] + + pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0] + end_node_id = self.layer_id_to_output_node_ids[end_id][0] + + skip_output_id = self._insert_pooling_layer_chain( + start_node_id, end_node_id) + + # Add the conv layer in order to align the number of channels with end layer id + new_conv_layer = get_conv_class( + self.n_dim)( + filters_start, + filters_end, + 1) + skip_output_id = self.add_layer(new_conv_layer, skip_output_id) + + # Add the add layer. + add_input_node_id = self._add_node( + deepcopy(self.node_list[end_node_id])) + add_layer = StubAdd() + + self._redirect_edge(pre_end_node_id, end_node_id, add_input_node_id) + self._add_edge(add_layer, add_input_node_id, end_node_id) + self._add_edge(add_layer, skip_output_id, end_node_id) + add_layer.input = [ + self.node_list[add_input_node_id], + self.node_list[skip_output_id], + ] + add_layer.output = self.node_list[end_node_id] + self.node_list[end_node_id].shape = add_layer.output_shape + + # Set weights to the additional conv layer. + if self.weighted: + filter_shape = (1,) * self.n_dim + weights = np.zeros((filters_end, filters_start) + filter_shape) + bias = np.zeros(filters_end) + new_conv_layer.set_weights( + (add_noise(weights, np.array([0, 1])), add_noise( + bias, np.array([0, 1]))) + ) + + def to_concat_skip_model(self, start_id, end_id): + """Add a weighted add concatenate connection from after start node to end node. + Args: + start_id: The convolutional layer ID, after which to start the skip-connection. + end_id: The convolutional layer ID, after which to end the skip-connection. 
+ """ + self.operation_history.append( + ("to_concat_skip_model", start_id, end_id)) + filters_end = self.layer_list[end_id].output.shape[-1] + filters_start = self.layer_list[start_id].output.shape[-1] + start_node_id = self.layer_id_to_output_node_ids[start_id][0] + + pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0] + end_node_id = self.layer_id_to_output_node_ids[end_id][0] + + skip_output_id = self._insert_pooling_layer_chain( + start_node_id, end_node_id) + + concat_input_node_id = self._add_node( + deepcopy(self.node_list[end_node_id])) + self._redirect_edge(pre_end_node_id, end_node_id, concat_input_node_id) + + concat_layer = StubConcatenate() + concat_layer.input = [ + self.node_list[concat_input_node_id], + self.node_list[skip_output_id], + ] + concat_output_node_id = self._add_node(Node(concat_layer.output_shape)) + self._add_edge( + concat_layer, + concat_input_node_id, + concat_output_node_id) + self._add_edge(concat_layer, skip_output_id, concat_output_node_id) + concat_layer.output = self.node_list[concat_output_node_id] + self.node_list[concat_output_node_id].shape = concat_layer.output_shape + + # Add the concatenate layer. + # concat过channel数增加,用conv class 回到原先的channel数 + new_conv_layer = get_conv_class(self.n_dim)( + filters_start + filters_end, filters_end, 1 + ) + self._add_edge(new_conv_layer, concat_output_node_id, end_node_id) + new_conv_layer.input = self.node_list[concat_output_node_id] + new_conv_layer.output = self.node_list[end_node_id] + self.node_list[end_node_id].shape = new_conv_layer.output_shape + + if self.weighted: + filter_shape = (1,) * self.n_dim + weights = np.zeros((filters_end, filters_end) + filter_shape) + for i in range(filters_end): + filter_weight = np.zeros((filters_end,) + filter_shape) + center_index = (i,) + (0,) * self.n_dim + filter_weight[center_index] = 1 + weights[i, ...] 
= filter_weight + weights = np.concatenate( + (weights, np.zeros((filters_end, filters_start) + filter_shape)), axis=1 + ) + bias = np.zeros(filters_end) + new_conv_layer.set_weights( + (add_noise(weights, np.array([0, 1])), add_noise( + bias, np.array([0, 1]))) + ) + + def _insert_pooling_layer_chain(self, start_node_id, end_node_id): + """ + insert pooling layer + """ + skip_output_id = start_node_id + # 得到从start_node_id 到 end_node_id之间的所有pooling layer(包括conv layer stride > 1) + for layer in self._get_pooling_layers(start_node_id, end_node_id): + new_layer = deepcopy(layer) + # 如果是conv层需要重新初始化weights + if is_layer(new_layer, "Conv"): + # start node id 的通道数 + filters = self.node_list[start_node_id].shape[-1] + new_layer = get_conv_class(self.n_dim)( + filters, filters, 1, layer.stride) + if self.weighted: + init_conv_weight(new_layer) + else: + new_layer = deepcopy(layer) + skip_output_id = self.add_layer(new_layer, skip_output_id) + skip_output_id = self.add_layer(StubReLU(), skip_output_id) + return skip_output_id + + def extract_descriptor(self): + """Extract the the description of the Graph as an instance of NetworkDescriptor.""" + main_chain = self.get_main_chain() + index_in_main_chain = {} + for index, u in enumerate(main_chain): + index_in_main_chain[u] = index + + ret = NetworkDescriptor() + for u in main_chain: + for v, layer_id in self.adj_list[u]: + if v not in index_in_main_chain: + continue + layer = self.layer_list[layer_id] + copied_layer = copy(layer) + copied_layer.weights = None + ret.add_layer(deepcopy(copied_layer)) + + for u in index_in_main_chain: + for v, layer_id in self.adj_list[u]: + if v not in index_in_main_chain: + temp_u = u + temp_v = v + temp_layer_id = layer_id + skip_type = None + while not ( + temp_v in index_in_main_chain and temp_u in index_in_main_chain): + if is_layer( + self.layer_list[temp_layer_id], "Concatenate"): + skip_type = NetworkDescriptor.CONCAT_CONNECT + if is_layer(self.layer_list[temp_layer_id], "Add"): + 
def get_main_chain_layers(self):
    """Return the IDs of the layers that connect two main-chain nodes.

    The main chain is taken from ``self.get_main_chain()``.  For every node
    ``u`` on it, each ``(v, layer_id)`` entry of ``self.adj_list[u]`` whose
    target ``v`` is also on the chain contributes ``layer_id``.

    Returns
    -------
    list
        Layer IDs in chain order, then in adjacency-list order.
    """
    main_chain = self.get_main_chain()
    # One membership test per edge: a set keeps it O(1) instead of scanning
    # the chain list.  ``u`` is drawn from the chain itself, so only the
    # edge target has to be checked.
    on_chain = set(main_chain)
    return [
        layer_id
        for u in main_chain
        for v, layer_id in self.adj_list[u]
        if v in on_chain
    ]
def get_main_chain(self):
    """Return the node IDs of the longest path in the graph, source first.

    Distances are computed by repeatedly relaxing every edge of
    ``self.adj_list`` (each edge has length 1); ``pre_node`` remembers the
    predecessor that produced the best distance.  The chain is then read
    backwards from the farthest node until a node that is its own
    predecessor is reached.

    Returns
    -------
    list
        Node IDs along the main chain; empty when the graph has no nodes.
    """
    n_nodes = self.n_nodes
    if n_nodes == 0:
        return []

    # Every node starts at distance 0 with itself as predecessor.
    distance = [0] * n_nodes
    pre_node = list(range(n_nodes))

    # At most n_nodes - 1 passes are needed; stop as soon as a full pass
    # changes nothing, since later passes could not change anything either.
    for _ in range(n_nodes - 1):
        changed = False
        for u in range(n_nodes):
            for v, _layer_id in self.adj_list[u]:
                if distance[u] + 1 > distance[v]:
                    distance[v] = distance[u] + 1
                    pre_node[v] = u
                    changed = True
        if not changed:
            break

    # Farthest node; ties resolve to the lowest node ID.
    temp_id = max(range(n_nodes), key=distance.__getitem__)

    # Walk the predecessor links back to the chain's first node.  A valid
    # chain visits at most n_nodes nodes, so the walk is bounded by that.
    ret = []
    for _ in range(n_nodes):
        ret.append(temp_id)
        if pre_node[temp_id] == temp_id:
            break
        temp_id = pre_node[temp_id]
    # Internal invariant: the walk must have ended on a chain root.
    assert temp_id == pre_node[temp_id]
    ret.reverse()
    return ret
def forward(self, input_tensor):
    """Run ``input_tensor`` through the graph in topological order.

    Each node slot holds the tensor produced for that node.  The first node
    of ``graph.topological_order`` receives ``input_tensor`` and the tensor
    stored for the last node is returned.

    Parameters
    ----------
    input_tensor
        Value fed to the first node in topological order.

    Returns
    -------
    The value computed for the last node in topological order.
    """
    graph = self.graph
    topo_node_list = graph.topological_order
    output_id = topo_node_list[-1]
    input_id = topo_node_list[0]

    # Slots are only rebound to tensors, never mutated, so a shallow copy of
    # the node list is enough; deep-copying every node on each call is not
    # needed.
    node_list = list(graph.node_list)
    node_list[input_id] = input_tensor

    for v in topo_node_list:
        for u, layer_id in graph.reverse_adj_list[v]:
            layer = graph.layer_list[layer_id]
            torch_layer = self.layers[layer_id]

            if isinstance(layer, (StubAdd, StubConcatenate)):
                # Aggregating layers consume the tensors of all their
                # input nodes at once.
                edge_input_tensor = [
                    node_list[node_id]
                    for node_id in graph.layer_id_to_input_node_ids[layer_id]
                ]
            else:
                edge_input_tensor = node_list[u]

            node_list[v] = torch_layer(edge_input_tensor)
    return node_list[output_id]
def graph_to_json(graph, json_model_path):
    """Serialize ``graph`` to JSON, write it to a file and return the text.

    Parameters
    ----------
    graph
        Object whose ``produce_json_model()`` returns JSON-serializable data.
    json_model_path
        Path of the file to (over)write.

    Returns
    -------
    str
        The JSON text that was written.
    """
    # Serialize once, and before opening the file, so a serialization error
    # cannot leave a truncated file behind.
    json_out = json.dumps(graph.produce_json_model())
    with open(json_model_path, "w", encoding="utf-8") as fout:
        fout.write(json_out)
    return json_out
def to_wider_graph(graph):
    """Widen one randomly chosen layer of ``graph`` in place.

    Candidates are the IDs from ``graph.wide_layer_ids()`` whose output node
    has a truthy last shape entry.  The chosen layer is widened by its own
    current width (``filters`` for conv layers, ``units`` otherwise), i.e.
    doubled.

    Returns
    -------
    The same ``graph`` after widening, or ``None`` when there is no
    candidate layer or doubling would exceed ``Constant.MAX_LAYER_WIDTH``.
    """
    candidates = [
        layer_id
        for layer_id in graph.wide_layer_ids()
        if graph.layer_list[layer_id].output.shape[-1]
    ]
    # ``sample`` raises ValueError on an empty population; treat "nothing to
    # widen" like the width-limit case instead of crashing.
    if not candidates:
        return None

    layer_id = sample(candidates, 1)[0]
    layer = graph.layer_list[layer_id]
    n_add = layer.filters if is_layer(layer, "Conv") else layer.units

    if n_add * 2 > Constant.MAX_LAYER_WIDTH:
        return None

    graph.to_wider_model(layer_id, n_add)
    return graph
+ weighted_layer_ids = graph.skip_connection_layer_ids() + valid_connection = [] + for skip_type in sorted( + [NetworkDescriptor.ADD_CONNECT, NetworkDescriptor.CONCAT_CONNECT]): + for index_a in range(len(weighted_layer_ids)): + for index_b in range(len(weighted_layer_ids))[index_a + 1:]: + valid_connection.append((index_a, index_b, skip_type)) + + if len(valid_connection) < 1: + return graph + for index_a, index_b, skip_type in sample(valid_connection, 1): + a_id = weighted_layer_ids[index_a] + b_id = weighted_layer_ids[index_b] + if skip_type == NetworkDescriptor.ADD_CONNECT: + graph.to_add_skip_model(a_id, b_id) + else: + graph.to_concat_skip_model(a_id, b_id) + return graph + + +def create_new_layer(layer, n_dim): + ''' create new layer for the graph + ''' + + input_shape = layer.output.shape + # 一般情况 + dense_deeper_classes = [StubDense, get_dropout_class(n_dim), StubReLU] + conv_deeper_classes = [ + get_conv_class(n_dim), + get_batch_norm_class(n_dim), + StubReLU] + # 三种情况有特别的layer class + if is_layer(layer, "ReLU"): + conv_deeper_classes = [ + get_conv_class(n_dim), + get_batch_norm_class(n_dim)] + dense_deeper_classes = [StubDense, get_dropout_class(n_dim)] + elif is_layer(layer, "Dropout"): + dense_deeper_classes = [StubDense, StubReLU] + elif is_layer(layer, "BatchNormalization"): + conv_deeper_classes = [get_conv_class(n_dim), StubReLU] + + layer_class = None + if len(input_shape) == 1: + # It is in the dense layer part. + layer_class = sample(dense_deeper_classes, 1)[0] + else: + # It is in the conv layer part. 
def to_deeper_graph(graph):
    """Insert one new layer after a randomly chosen layer of ``graph``.

    Candidates come from ``graph.deep_layer_ids()``; the layer to insert is
    built by ``create_new_layer`` and added with ``graph.to_deeper_model``.

    Returns
    -------
    The same ``graph`` after the insertion, or ``None`` when there is no
    candidate layer or the candidate count has reached
    ``Constant.MAX_LAYERS``.
    """
    weighted_layer_ids = graph.deep_layer_ids()
    # ``sample`` raises ValueError on an empty population; report "cannot
    # deepen" the same way as the depth limit does.
    if not weighted_layer_ids:
        return None
    if len(weighted_layer_ids) >= Constant.MAX_LAYERS:
        return None

    layer_id = sample(weighted_layer_ids, 1)[0]
    layer = graph.layer_list[layer_id]
    new_layer = create_new_layer(layer, graph.n_dim)
    graph.to_deeper_model(layer_id, new_layer)
    return graph
def wider_pre_dense(layer, n_add, weighted=True):
    """Return a copy of dense ``layer`` with ``n_add`` extra output units.

    Without weights only the enlarged ``StubDense`` is built.  With weights,
    each new unit duplicates a randomly picked existing unit (weight row and
    bias entry), perturbed by ``add_noise``.

    Parameters
    ----------
    layer
        Dense stub providing ``input_units``, ``units`` and ``get_weights()``.
    n_add : int
        Number of units to append.
    weighted : bool
        Whether weights should be carried over.
    """
    old_units = layer.units
    widened = StubDense(layer.input_units, old_units + n_add)
    if not weighted:
        return widened

    teacher_w, teacher_b = layer.get_weights()
    picks = np.random.randint(old_units, size=n_add)
    student_w, student_b = teacher_w.copy(), teacher_b.copy()

    # Append one duplicated (and slightly noisy) unit per pick.
    for teacher_index in picks:
        row = teacher_w[teacher_index, :][np.newaxis, :]
        student_w = np.concatenate(
            (student_w, add_noise(row, student_w)), axis=0)
        student_b = np.append(
            student_b, add_noise(teacher_b[teacher_index], student_b))

    widened.set_weights((student_w, student_b))
    return widened
def wider_next_conv(layer, start_dim, total_dim, n_add, weighted=True):
    """Return a copy of conv ``layer`` accepting ``n_add`` more input channels.

    With weights, a noisy all-zero slab of ``n_add`` input channels is
    inserted at channel position ``start_dim``; the original channels
    ``[:start_dim]`` and ``[start_dim:total_dim]`` are kept around it and
    the bias is reused unchanged.

    Parameters
    ----------
    layer
        Conv stub providing ``input_channel``, ``filters``, ``kernel_size``,
        ``stride`` and ``get_weights()``.
    start_dim : int
        Input-channel index at which the new channels are inserted.
    total_dim : int
        Upper bound of the original input channels that are kept.
    n_add : int
        Number of input channels to insert.
    weighted : bool
        Whether weights should be carried over.
    """
    conv_class = get_conv_class(get_n_dim(layer))
    new_layer = conv_class(layer.input_channel + n_add,
                           layer.filters,
                           kernel_size=layer.kernel_size,
                           stride=layer.stride)
    if not weighted:
        return new_layer

    teacher_w, teacher_b = layer.get_weights()

    # Same shape as the teacher kernel, except axis 1 has n_add entries.
    inserted_shape = tuple(teacher_w.shape[:1]) + (n_add,) + tuple(teacher_w.shape[2:])
    inserted = add_noise(np.zeros(inserted_shape), teacher_w)

    head = teacher_w[:, :start_dim, ...]
    tail = teacher_w[:, start_dim:total_dim, ...]
    student_w = np.concatenate((head, inserted, tail), axis=1)

    new_layer.set_weights((student_w, teacher_b))
    return new_layer
def add_noise(weights, other_weights, noise_ratio=None):
    """Return ``weights`` plus uniform noise scaled by the spread of ``other_weights``.

    The noise is drawn from ``[-r / 2, r / 2)`` where ``r`` is
    ``noise_ratio`` times the peak-to-peak range (max - min) of
    ``other_weights``.

    Parameters
    ----------
    weights : array-like
        Values to perturb; the result has the same shape.
    other_weights : array-like
        Reference values whose value range sets the noise amplitude.
    noise_ratio : float, optional
        Fraction of the reference range used as noise width.  Defaults to the
        module-level ``NOISE_RATIO``.

    Returns
    -------
    numpy.ndarray
        ``weights`` with noise added.
    """
    if noise_ratio is None:
        noise_ratio = NOISE_RATIO
    # ravel() accepts scalars and nested sequences and avoids the copy that
    # flatten() always makes.
    w_range = np.ptp(np.ravel(other_weights))
    noise_range = noise_ratio * w_range
    noise = np.random.uniform(-noise_range / 2.0,
                              noise_range / 2.0, np.shape(weights))
    return np.add(noise, weights)
def init_conv_weight(layer):
    """Give conv ``layer`` near-identity weights and a near-zero bias.

    The kernel has shape ``(filters, filters) + (kernel_size,) * n_dim``.
    For each filter ``i`` only the entry at input channel ``i`` and the
    kernel centre is set to 1; everything else is 0.  Both kernel and bias
    are then perturbed with ``add_noise`` (kernel first, bias second) and
    stored via ``layer.set_weights``.
    """
    n_filters = layer.filters
    kernel_shape = (layer.kernel_size,) * get_n_dim(layer)
    center = tuple(int((extent - 1) / 2) for extent in kernel_shape)

    weight = np.zeros((n_filters, n_filters) + kernel_shape)
    for i in range(n_filters):
        # Filter i passes input channel i through at the kernel centre.
        weight[(i, i) + center] = 1
    bias = np.zeros(n_filters)

    unit_range = np.array([0, 1])
    noisy_weight = add_noise(weight, unit_range)
    noisy_bias = add_noise(bias, np.array([0, 1]))
    layer.set_weights((noisy_weight, noisy_bias))
class StubLayer:
    """
    Base class of the stub layers: a torch-free description of one layer.

    It stores the input node(s), the output node and an optional weight
    container.  The hook methods below do nothing here and are meant to be
    overridden by subclasses.
    """

    def __init__(self, input_node=None, output_node=None):
        self.weights = None
        self.input, self.output = input_node, output_node

    def build(self, shape):
        """
        Hook taking a shape; no-op in the base class.
        """

    def set_weights(self, weights):
        """
        Store ``weights`` on the stub.
        """
        self.weights = weights

    def import_weights(self, torch_layer):
        """
        Hook for copying weights from ``torch_layer``; no-op here.
        """

    def export_weights(self, torch_layer):
        """
        Hook for copying weights into ``torch_layer``; no-op here.
        """

    def get_weights(self):
        """
        Return the stored weights (``None`` until set).
        """
        return self.weights

    def size(self):
        """
        Size contribution of this layer; 0 for the base class.
        """
        return 0

    @property
    def output_shape(self):
        """
        Output shape; by default the shape of the input node.
        """
        return self.input.shape

    def to_real_layer(self):
        """
        Hook for building the real layer; returns nothing here.
        """

    def __str__(self):
        """
        Class name without its first four characters ("Stub" for the
        classes of this module).
        """
        cls_name = self.__class__.__name__
        return cls_name[4:]
class StubConv(StubWeightBiasLayer):
    """
    Stub of a convolution layer; dimension-specific subclasses build the
    real layer.
    """

    def __init__(self, input_channel, filters, kernel_size,
                 stride=1, input_node=None, output_node=None):
        super().__init__(input_node, output_node)
        self.input_channel, self.filters = input_channel, filters
        self.kernel_size, self.stride = kernel_size, stride
        # Padding is derived from the kernel size (half of it, truncated).
        self.padding = int(self.kernel_size / 2)

    @property
    def output_shape(self):
        # All but the last entry of the input shape are convolved extents;
        # the last entry is replaced by the number of filters.
        spatial = tuple(
            int((extent + 2 * self.padding - self.kernel_size) / self.stride) + 1
            for extent in self.input.shape[:-1]
        )
        return spatial + (self.filters,)

    def size(self):
        per_filter = self.input_channel * self.kernel_size * self.kernel_size + 1
        return per_filter * self.filters

    @abstractmethod
    def to_real_layer(self):
        pass

    def __str__(self):
        fields = (self.input_channel, self.filters, self.kernel_size, self.stride)
        return "{}({})".format(super().__str__(), ", ".join(map(str, fields)))
+ """ + @property + def output_shape(self): + ret = 0 + for current_input in self.input: + ret += current_input.shape[-1] + ret = self.input[0].shape[:-1] + (ret,) + return ret + + def to_real_layer(self): + return TorchConcatenate() + + +class StubAdd(StubAggregateLayer): + """ + StubAdd Module. + """ + @property + def output_shape(self): + return self.input[0].shape + + def to_real_layer(self): + return TorchAdd() + + +class StubFlatten(StubLayer): + """ + StubFlatten Module. + """ + @property + def output_shape(self): + ret = 1 + for dim in self.input.shape: + ret *= dim + return (ret,) + + def to_real_layer(self): + return TorchFlatten() + + +class StubReLU(StubLayer): + """ + StubReLU Module. + """ + + def to_real_layer(self): + return torch.nn.ReLU() + + +class StubSoftmax(StubLayer): + """ + StubSoftmax Module. + """ + + def to_real_layer(self): + return torch.nn.LogSoftmax(dim=1) + + +class StubDropout(StubLayer): + """ + StubDropout Module. + """ + + def __init__(self, rate, input_node=None, output_node=None): + super().__init__(input_node, output_node) + self.rate = rate + + @abstractmethod + def to_real_layer(self): + pass + + +class StubDropout1d(StubDropout): + """ + StubDropout1d Module. + """ + + def to_real_layer(self): + return torch.nn.Dropout(self.rate) + + +class StubDropout2d(StubDropout): + """ + StubDropout2d Module. + """ + + def to_real_layer(self): + return torch.nn.Dropout2d(self.rate) + + +class StubDropout3d(StubDropout): + """ + StubDropout3d Module. + """ + + def to_real_layer(self): + return torch.nn.Dropout3d(self.rate) + + +class StubInput(StubLayer): + """ + StubInput Module. + """ + + def __init__(self, input_node=None, output_node=None): + super().__init__(input_node, output_node) + + +class StubPooling(StubLayer): + """ + StubPooling Module. 
class StubPooling1d(StubPooling):
    """
    StubPooling1d Module.
    """

    def to_real_layer(self):
        # Pass the stub's padding on: StubPooling.output_shape already
        # accounts for it, so the real layer must apply it as well.
        return torch.nn.MaxPool1d(
            self.kernel_size, stride=self.stride, padding=self.padding)


class StubPooling2d(StubPooling):
    """
    StubPooling2d Module.
    """

    def to_real_layer(self):
        # See StubPooling1d: keep the real layer consistent with output_shape.
        return torch.nn.MaxPool2d(
            self.kernel_size, stride=self.stride, padding=self.padding)


class StubPooling3d(StubPooling):
    """
    StubPooling3d Module.
    """

    def to_real_layer(self):
        # See StubPooling1d: keep the real layer consistent with output_shape.
        return torch.nn.MaxPool3d(
            self.kernel_size, stride=self.stride, padding=self.padding)
def is_layer(layer, layer_type):
    """
    Tell whether ``layer`` is a stub of the kind named by ``layer_type``.

    Parameters
    ----------
    layer
        Object to test.
    layer_type : str
        One of "Input", "Conv", "Dense", "BatchNormalization",
        "Concatenate", "Add", "Pooling", "Dropout", "Softmax", "ReLU",
        "Flatten" or "GlobalAveragePooling".

    Returns
    -------
    bool
        True when ``layer`` is an instance of the matching stub class;
        False otherwise, including for an unknown ``layer_type``.
    """
    if layer_type == "Input":
        return isinstance(layer, StubInput)
    if layer_type == "Conv":
        return isinstance(layer, StubConv)
    if layer_type == "Dense":
        return isinstance(layer, StubDense)
    if layer_type == "BatchNormalization":
        return isinstance(layer, StubBatchNormalization)
    if layer_type == "Concatenate":
        return isinstance(layer, StubConcatenate)
    if layer_type == "Add":
        return isinstance(layer, StubAdd)
    if layer_type == "Pooling":
        return isinstance(layer, StubPooling)
    if layer_type == "Dropout":
        return isinstance(layer, StubDropout)
    if layer_type == "Softmax":
        return isinstance(layer, StubSoftmax)
    if layer_type == "ReLU":
        return isinstance(layer, StubReLU)
    if layer_type == "Flatten":
        return isinstance(layer, StubFlatten)
    if layer_type == "GlobalAveragePooling":
        return isinstance(layer, StubGlobalPooling)
    # Unknown type name: answer with a real bool, as documented.
    return False
def layer_description_builder(layer_information, id_to_node):
    """Rebuild a stub layer from a description sequence.

    Parameters
    ----------
    layer_information : sequence
        ``[class name, input node id(s) or None, output node id or None,
        *constructor arguments]``.  The constructor arguments read depend on
        the class-name prefix (conv: 4, dense: 2, batch norm: 1, dropout: 1,
        pooling: 3, anything else: none).
    id_to_node : mapping
        Maps node IDs to node objects.

    Returns
    -------
    The stub layer instance, wired to the referenced nodes.

    Raises
    ------
    KeyError
        If the class name is not a stub class of this module, or a node ID
        is missing from ``id_to_node``.
    """
    layer_type = layer_information[0]
    # The class is resolved by name from this module's globals.  Only accept
    # stub class names so a crafted description cannot call other globals.
    if not layer_type.startswith("Stub"):
        raise KeyError(layer_type)
    layer_class = globals()[layer_type]

    # A description may carry None for an unset input or output; keep it as
    # None instead of failing the node lookup.
    layer_input_ids = layer_information[1]
    if layer_input_ids is None:
        layer_input = None
    elif isinstance(layer_input_ids, Iterable):
        layer_input = [id_to_node[node_id] for node_id in layer_input_ids]
    else:
        layer_input = id_to_node[layer_input_ids]

    layer_output_id = layer_information[2]
    layer_output = None if layer_output_id is None else id_to_node[layer_output_id]

    if layer_type.startswith("StubConv"):
        input_channel = layer_information[3]
        filters = layer_information[4]
        kernel_size = layer_information[5]
        stride = layer_information[6]
        return layer_class(
            input_channel, filters, kernel_size, stride, layer_input, layer_output
        )
    if layer_type.startswith("StubDense"):
        input_units = layer_information[3]
        units = layer_information[4]
        return layer_class(input_units, units, layer_input, layer_output)
    if layer_type.startswith("StubBatchNormalization"):
        num_features = layer_information[3]
        return layer_class(num_features, layer_input, layer_output)
    if layer_type.startswith("StubDropout"):
        rate = layer_information[3]
        return layer_class(rate, layer_input, layer_output)
    if layer_type.startswith("StubPooling"):
        kernel_size = layer_information[3]
        stride = layer_information[4]
        padding = layer_information[5]
        return layer_class(kernel_size, stride, padding, layer_input, layer_output)
    return layer_class(layer_input, layer_output)
def get_n_dim(layer):
    """
    Return the spatial dimensionality (1, 2 or 3) encoded in a stub's class.

    Only the dimension-specific conv, dropout, global-pooling, pooling and
    batch-normalization stubs are recognised; any other object yields ``-1``.
    """
    stub_families = (
        (1, (
            StubConv1d,
            StubDropout1d,
            StubGlobalPooling1d,
            StubPooling1d,
            StubBatchNormalization1d,
        )),
        (2, (
            StubConv2d,
            StubDropout2d,
            StubGlobalPooling2d,
            StubPooling2d,
            StubBatchNormalization2d,
        )),
        (3, (
            StubConv3d,
            StubDropout3d,
            StubGlobalPooling3d,
            StubPooling3d,
            StubBatchNormalization3d,
        )),
    )
    for dimension, family in stub_families:
        if isinstance(layer, family):
            return dimension
    # Sentinel for layers that carry no dimensionality.
    return -1
class NetworkMorphismSearcher:
    """
    NetworkMorphismSearcher is a tuner which uses network morphism techniques.

    Attributes
    ----------
    path : str
        Directory that holds one sub-directory per model id; each contains
        that model's ``model_selected_space.json``.
    best_selected_space_path : str
        Destination the best model's JSON file is copied to.
    generators : list
        Generator classes used to create the initial architectures.
    n_classes : int
        The class number or output node number (default: ``10``)
    input_shape : tuple
        A tuple ``(input_width, input_width, input_channel)``
    t_min : float
        The minimum temperature for simulated annealing. (default: ``Constant.T_MIN``)
    beta : float
        The beta in acquisition function. (default: ``Constant.BETA``)
    algorithm_name : str
        algorithm name used in the network morphism (default: ``"Bayesian"``)
    optimize_mode : OptimizeMode
        built from the ``optimize_mode`` string (default: ``"maximize"``)
    verbose : bool
        verbose to print the log (default: ``True``)
    bo : BayesianOptimizer
        The optimizer used in the network morphism tuner.
    training_queue : list
        Pending ``(graph, father_id, model_id)`` tuples.
    descriptors : list
        Descriptors of every graph queued so far.
    history : list of dict
        ``{"model_id": ..., "metric_value": ...}`` for each evaluated model.
    total_data : dict
        ``parameter_id -> (json_out, father_id, model_id)``.
    max_model_size : int
        max model size to the graph (default: ``Constant.MAX_MODEL_SIZE``)
    default_model_len : int
        default model length (default: ``Constant.MODEL_LEN``)
    default_model_width : int
        default model width (default: ``Constant.MODEL_WIDTH``)
    """

    def __init__(
            self,
            path,
            best_selected_space_path,
            task="cv",
            input_width=32,
            input_channel=3,
            n_output_node=10,
            algorithm_name="Bayesian",
            optimize_mode="maximize",
            verbose=True,
            beta=Constant.BETA,
            t_min=Constant.T_MIN,
            max_model_size=Constant.MAX_MODEL_SIZE,
            default_model_len=Constant.MODEL_LEN,
            default_model_width=Constant.MODEL_WIDTH,
    ):
        """
        Initializer of the NetworkMorphismSearcher.

        Raises
        ------
        NotImplementedError
            If ``task`` is neither ``"cv"`` nor ``"common"``.
        """
        self.path = path
        self.best_selected_space_path = best_selected_space_path

        if task == "cv":
            self.generators = [CnnGenerator]
        elif task == "common":
            self.generators = [MlpGenerator]
        else:
            # The placeholder must be filled in, otherwise the message shows
            # a literal "{}" instead of the offending task name.
            raise NotImplementedError(
                '{} task not supported in List ["cv","common"]'.format(task))

        self.n_classes = n_output_node
        self.input_shape = (input_width, input_width, input_channel)

        self.t_min = t_min
        self.beta = beta
        self.algorithm_name = algorithm_name
        self.optimize_mode = OptimizeMode(optimize_mode)
        self.json = None
        self.total_data = {}
        self.verbose = verbose

        self.bo = BayesianOptimizer(
            self, self.t_min, self.optimize_mode, self.beta)

        self.training_queue = []
        self.descriptors = []
        self.history = []

        self.max_model_size = max_model_size
        self.default_model_len = default_model_len
        self.default_model_width = default_model_width

    def search(self, parameter_id, args):
        """
        Return the next trial architecture as a serializable object.

        Parameters
        ----------
        parameter_id : int
            Key under which the produced architecture is remembered in
            ``total_data``.
        args : object
            Must expose ``trial_id``, which is used as the new model's id.

        Returns
        -------
        object
            Whatever ``graph_to_json`` returns for the selected graph.
        """
        if not self.history:
            self.init_search(args)

        new_father_id = None
        generated_graph = None
        # First check whether the training queue already holds something; if
        # so it was filled by init_search. Otherwise (there is history
        # already) generate a new architecture.
        if not self.training_queue:
            new_father_id, generated_graph = self.generate()
            new_model_id = args.trial_id
            self.training_queue.append(
                (generated_graph, new_father_id, new_model_id))
            self.descriptors.append(generated_graph.extract_descriptor())

        graph, father_id, model_id = self.training_queue.pop(0)

        # from graph to json
        json_out = graph_to_json(
            graph,
            os.path.join(self.path, str(model_id), 'model_selected_space.json'))
        self.total_data[parameter_id] = (json_out, father_id, model_id)

        return json_out

    def update_searcher(self, parameter_id, value, **kwargs):
        """
        Record an observation of the objective function.

        Parameters
        ----------
        parameter_id : int
            The id previously passed to :meth:`search`.
        value : float
            Metric of the trained model; it is forwarded unchanged to
            :meth:`add_model` and :meth:`update`.

        Raises
        ------
        RuntimeError
            If ``parameter_id`` was never produced by :meth:`search`.
        """
        if parameter_id not in self.total_data:
            raise RuntimeError("Received parameter_id not in total_data.")

        (_, father_id, model_id) = self.total_data[parameter_id]

        graph = self.bo.searcher.load_model_by_id(model_id)

        # to use the value and graph
        self.add_model(value, model_id)
        self.update(father_id, graph, value, model_id)

    def init_search(self, args):
        """
        Call the generators to generate the initial architectures for the search.

        Every generated graph is queued with father id ``-1`` and
        ``args.trial_id`` as its model id.
        """
        if self.verbose:
            logger.info("Initializing search.")
        for generator in self.generators:
            graph = generator(self.n_classes, self.input_shape).generate(
                self.default_model_len, self.default_model_width
            )
            model_id = args.trial_id
            self.training_queue.append((graph, -1, model_id))
            self.descriptors.append(graph.extract_descriptor())

        if self.verbose:
            logger.info("Initialization finished.")

    def generate(self):
        """
        Generate the next neural architecture.

        Returns
        -------
        other_info : any object
            Anything to be saved in the training queue together with the
            architecture (here: the father id).
        generated_graph : Graph
            An instance of Graph.
        """
        generated_graph, new_father_id = self.bo.generate(self.descriptors)
        if new_father_id is None:
            # The optimizer proposed nothing: fall back to a fresh default
            # architecture attached to father id 0.
            new_father_id = 0
            generated_graph = self.generators[0](
                self.n_classes, self.input_shape
            ).generate(self.default_model_len, self.default_model_width)

        return new_father_id, generated_graph

    def update(self, other_info, graph, metric_value, model_id):
        """
        Update the controller with evaluation result of a neural architecture.

        Parameters
        ----------
        other_info : any object
            In our case it is the father ID in the search tree.
        graph : Graph
            An instance of Graph. The trained neural architecture.
        metric_value : float
            The final evaluated metric value.
        model_id : int
        """
        father_id = other_info
        self.bo.fit([graph.extract_descriptor()], [metric_value])
        self.bo.add_child(father_id, model_id)

    def add_model(self, metric_value, model_id):
        """
        Append a model to the history and refresh the best-model file.

        Parameters
        ----------
        metric_value : float
        model_id : int

        Returns
        -------
        dict
            The history entry ``{"model_id": ..., "metric_value": ...}``.
        """
        if self.verbose:
            logger.info("Saving model.")

        ret = {"model_id": model_id, "metric_value": metric_value}
        self.history.append(ret)
        # update best selected space
        if model_id == self.get_best_model_id():
            best_model_path = os.path.join(
                self.path, str(model_id), 'model_selected_space.json')
            shutil.copy(best_model_path, self.best_selected_space_path)
        return ret

    def get_best_model_id(self):
        """
        Get the best model_id from history using the metric value.

        The largest metric wins in maximize mode, the smallest otherwise.
        """
        if self.optimize_mode is OptimizeMode.Maximize:
            return max(self.history, key=lambda x: x["metric_value"])[
                "model_id"]
        return min(self.history, key=lambda x: x["metric_value"])["model_id"]

    def load_model_by_id(self, model_id):
        """
        Get the model by model_id

        Parameters
        ----------
        model_id : int
            model index

        Returns
        -------
        load_model : Graph
            the model graph representation
        """
        with open(os.path.join(self.path, str(model_id), "model_selected_space.json")) as fin:
            json_str = fin.read().replace("\n", "")

        load_model = json_to_graph(json_str)
        return load_model

    def load_best_model(self):
        """
        Get the best model by model id

        Returns
        -------
        load_model : Graph
            the model graph representation
        """
        return self.load_model_by_id(self.get_best_model_id())

    def get_metric_value_by_id(self, model_id):
        """
        Get the model metric value by its model_id

        Parameters
        ----------
        model_id : int
            model index

        Returns
        -------
        float or None
            the model metric, or ``None`` when the id is not in the history
        """
        for item in self.history:
            if item["model_id"] == model_id:
                return item["metric_value"]
        return None
+ """ + + def __init__(self, n_output_node, input_shape): + super(CnnGenerator, self).__init__(n_output_node, input_shape) + self.n_dim = len(self.input_shape) - 1 + if len(self.input_shape) > 4: + raise ValueError("The input dimension is too high.") + if len(self.input_shape) < 2: + raise ValueError("The input dimension is too low.") + self.conv = get_conv_class(self.n_dim) + self.dropout = get_dropout_class(self.n_dim) + self.global_avg_pooling = get_global_avg_pooling_class(self.n_dim) + self.pooling = get_pooling_class(self.n_dim) + self.batch_norm = get_batch_norm_class(self.n_dim) + + def generate(self, model_len=None, model_width=None): + """Generates a CNN. + Args: + model_len: An integer. Number of convolutional layers. + model_width: An integer. Number of filters for the convolutional layers. + Returns: + An instance of the class Graph. Represents the neural architecture graph of the generated model. + """ + + if model_len is None: + model_len = Constant.MODEL_LEN + if model_width is None: + model_width = Constant.MODEL_WIDTH + pooling_len = int(model_len / 4) + graph = Graph(self.input_shape, False) + temp_input_channel = self.input_shape[-1] + output_node_id = 0 + stride = 1 + for i in range(model_len): + output_node_id = graph.add_layer(StubReLU(), output_node_id) + output_node_id = graph.add_layer( + self.batch_norm( + graph.node_list[output_node_id].shape[-1]), output_node_id + ) + output_node_id = graph.add_layer( + self.conv( + temp_input_channel, + model_width, + kernel_size=3, + stride=stride), + output_node_id, + ) + temp_input_channel = model_width + if pooling_len == 0 or ( + (i + 1) % pooling_len == 0 and i != model_len - 1): + output_node_id = graph.add_layer( + self.pooling(), output_node_id) + + output_node_id = graph.add_layer( + self.global_avg_pooling(), output_node_id) + output_node_id = graph.add_layer( + self.dropout(Constant.CONV_DROPOUT_RATE), output_node_id + ) + output_node_id = graph.add_layer( + 
class ResNetGenerator(NetworkGenerator):
    """A class to generate a residual network.

    Attributes:
        in_planes: Channel count fed into the next residual block. It is
            reset at the start of every `generate` call and advanced by
            `_make_layer`.
        block_expansion: Multiplier applied to a block's `planes` to obtain
            its output channel count (fixed to 1).
        n_dim: `len(self.input_shape) - 1`
        conv, dropout, global_avg_pooling, adaptive_avg_pooling, batch_norm:
            Stub classes selected for `n_dim`.
    """

    def __init__(self, n_output_node, input_shape):
        """Initialize the instance.
        Args:
            n_output_node: An integer. Number of output nodes in the network.
            input_shape: A tuple of length 2 to 4.
        Raises:
            ValueError: If `input_shape` has more than 4 or fewer than 2 entries.
        """
        super(ResNetGenerator, self).__init__(n_output_node, input_shape)
        self.in_planes = 64
        self.block_expansion = 1
        self.n_dim = len(self.input_shape) - 1
        if len(self.input_shape) > 4:
            raise ValueError('The input dimension is too high.')
        elif len(self.input_shape) < 2:
            raise ValueError('The input dimension is too low.')
        self.conv = get_conv_class(self.n_dim)
        self.dropout = get_dropout_class(self.n_dim)
        self.global_avg_pooling = get_global_avg_pooling_class(self.n_dim)
        self.adaptive_avg_pooling = get_global_avg_pooling_class(self.n_dim)
        self.batch_norm = get_batch_norm_class(self.n_dim)

    def generate(self, model_len=None, model_width=None):
        """Generates a residual network with four stages of two blocks each.
        Args:
            model_len: Unused; accepted for interface compatibility with the
                other generators.
            model_width: An integer. Channel count of the stem convolution and
                of the first stage; it doubles for each later stage.
        Returns:
            An instance of the class Graph.
        """
        if model_width is None:
            model_width = Constant.MODEL_WIDTH
        # The stem below emits `model_width` channels, so the first residual
        # block must consume exactly that many. Resetting here also undoes
        # the mutation left behind by a previous `generate` call.
        self.in_planes = model_width

        graph = Graph(self.input_shape, False)
        temp_input_channel = self.input_shape[-1]
        output_node_id = 0
        output_node_id = graph.add_layer(
            self.conv(temp_input_channel, model_width, kernel_size=3), output_node_id)
        output_node_id = graph.add_layer(self.batch_norm(model_width), output_node_id)

        # Stage 1 keeps the resolution (stride 1); stages 2-4 use stride 2
        # in their first block and double the width.
        output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 1)
        model_width *= 2
        output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2)
        model_width *= 2
        output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2)
        model_width *= 2
        output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2)

        output_node_id = graph.add_layer(self.global_avg_pooling(), output_node_id)
        graph.add_layer(
            StubDense(model_width * self.block_expansion, self.n_output_node),
            output_node_id)
        return graph

    def _make_layer(self, graph, planes, blocks, node_id, stride):
        """Append `blocks` residual blocks; only the first one uses `stride`.

        Returns the id of the last node added.
        """
        strides = [stride] + [1] * (blocks - 1)
        out = node_id
        for current_stride in strides:
            out = self._make_block(graph, self.in_planes, planes, out, current_stride)
            # The next block consumes what this one produced.
            self.in_planes = planes * self.block_expansion
        return out

    def _make_block(self, graph, in_planes, planes, node_id, stride=1):
        """Append one residual block and return the id of its Add node.

        Main path: BN -> ReLU -> 3-wide conv (stride) -> BN -> ReLU -> 3-wide conv.
        Shortcut: branches off after the first ReLU, then ReLU -> 1-wide conv
        (stride) so its shape matches the main path before the addition.
        """
        out = graph.add_layer(self.batch_norm(in_planes), node_id)
        out = graph.add_layer(StubReLU(), out)
        residual_node_id = out
        out = graph.add_layer(self.conv(in_planes, planes, kernel_size=3, stride=stride), out)
        out = graph.add_layer(self.batch_norm(planes), out)
        out = graph.add_layer(StubReLU(), out)
        out = graph.add_layer(self.conv(planes, planes, kernel_size=3), out)

        residual_node_id = graph.add_layer(StubReLU(), residual_node_id)
        residual_node_id = graph.add_layer(self.conv(in_planes,
                                                     planes * self.block_expansion,
                                                     kernel_size=1,
                                                     stride=stride), residual_node_id)
        out = graph.add_layer(StubAdd(), (out, residual_node_id))
        return out
If it is a list, it represents the + number of nodes in each hidden layer. If it is an integer, all hidden layers have nodes equal to this + value. + Returns: + An instance of the class Graph. Represents the neural architecture graph of the generated model. + """ + if model_len is None: + model_len = Constant.MODEL_LEN + if model_width is None: + model_width = Constant.MODEL_WIDTH + if isinstance(model_width, list) and not len(model_width) == model_len: + raise ValueError( + "The length of 'model_width' does not match 'model_len'") + elif isinstance(model_width, int): + model_width = [model_width] * model_len + + graph = Graph(self.input_shape, False) + output_node_id = 0 + n_nodes_prev_layer = self.input_shape[0] + for width in model_width: + output_node_id = graph.add_layer( + StubDense(n_nodes_prev_layer, width), output_node_id + ) + output_node_id = graph.add_layer( + StubDropout1d(Constant.MLP_DROPOUT_RATE), output_node_id + ) + output_node_id = graph.add_layer(StubReLU(), output_node_id) + n_nodes_prev_layer = width + + graph.add_layer( + StubDense( + n_nodes_prev_layer, + self.n_output_node), + output_node_id) + return graph diff --git a/dubhe-tadl/network_morphism/datasets.py b/dubhe-tadl/network_morphism/datasets.py new file mode 100644 index 0000000..c846bc5 --- /dev/null +++ b/dubhe-tadl/network_morphism/datasets.py @@ -0,0 +1,64 @@ +import numpy as np +import torch +import torchvision.transforms as transforms +from utils import Constant + +class Cutout: + """Randomly mask out one or more patches from an image. + Args: + n_holes (int): Number of patches to cut out of each image. + length (int): The length (in pixels) of each square patch. + """ + + def __init__(self, n_holes, length): + self.n_holes = n_holes + self.length = length + + def __call__(self, img): + """ + Args: + img (Tensor): Tensor image of size (C, H, W). + Returns: + Tensor: Image with n_holes of dimension length x length cut out of it. 
class Retrain:
    """Retrain the architecture stored in the best-selected-space file.

    Parameters
    ----------
    args : object
        Parsed command-line options. ``best_selected_space_path`` is read
        here; the whole object is also handed to ``NetworkMorphismTrainer``.
    """

    def __init__(self, args):
        self.args = args

    def run(self):
        """Load the best architecture JSON and run the trainer's retraining."""
        logger.info("Retraining the best model.")
        # Use the options this instance was built with; relying on a
        # module-level ``args`` only works when the file runs as a script.
        with open(self.args.best_selected_space_path, 'r') as f:
            json_out = json.load(f)
        # Re-serialize: the trainer receives the model as a JSON string.
        json_out = json.dumps(json_out)

        trainer = NetworkMorphismTrainer(json_out, self.args)
        trainer.retrain()
"__main__": + parser = argparse.ArgumentParser("network_morphism_retrain") + parser.add_argument("--trial_id", type=int, default=0, help="Trial id") + parser.add_argument("--log_path", type=str, + default='./log', help="log for info") + parser.add_argument( + "--experiment_dir", type=str, default='./TADL', help="experiment level path" + ) + parser.add_argument( + "--best_selected_space_path", type=str, default='./best_selected_space.json', help="Path to best selected space" + ) + parser.add_argument( + "--result_path", type=str, default='./result.json', help="Path to result" + ) + parser.add_argument( + "--best_checkpoint_dir", type=str, default='./', help="Path to checkpoint saved" + ) + parser.add_argument( + "--data_dir", type=str, default='../data/', help="Path to dataset" + ) + parser.add_argument("--batch_size", type=int, + default=128, help="batch size") + parser.add_argument("--opt", type=str, default="SGD", help="optimizer") + parser.add_argument("--epochs", type=int, default=200, help="epoch limit") + parser.add_argument( + "--lr", type=float, default=0.001, help="learning rate" + ) + args = parser.parse_args() + mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir) + init_logger(args.log_path) + retrain = Retrain(args) + retrain.run() diff --git a/dubhe-tadl/network_morphism/network_morphism_select.py b/dubhe-tadl/network_morphism/network_morphism_select.py new file mode 100644 index 0000000..0e8772d --- /dev/null +++ b/dubhe-tadl/network_morphism/network_morphism_select.py @@ -0,0 +1,22 @@ +import sys +sys.path.append('..'+ '/' + '..') +from argparse import ArgumentParser +from pytorch.selector import Selector + + +class NetworkMorphismSelector(Selector): + def __init__(self, single_candidate=True): + super().__init__(single_candidate) + + def fit(self): + pass + + +if __name__ == "__main__": + parser = ArgumentParser("NetworkMorphism select") + parser.add_argument("--best_selected_space_path", type=str, + 
def create_dir(path):
    """Make sure ``path`` exists and return it.

    Missing parent directories are created as well. An already existing
    path is left untouched.

    Parameters
    ----------
    path : str
        Directory to create.

    Returns
    -------
    str
        ``path`` itself, unchanged.
    """
    if not os.path.exists(path):
        # ``exist_ok`` covers the window between the check above and the
        # creation, e.g. when another trial process creates the same
        # directory at the same moment.
        os.makedirs(path, exist_ok=True)
    return path
if __name__ == "__main__":
    # One row per command-line option: (flag, type, default, help text).
    cli_options = (
        ("--trial_id", int, 0, "Trial id"),
        ("--data_dir", str, '../data/', "Path to dataset"),
        ("--log_path", str, './log', "Path to log file"),
        ("--experiment_dir", str, './TADL', "experiment level path"),
        ("--result_path", str, './result.json', "trial level path to result"),
        ("--search_space_path", str, './search_space.json',
         "experiment level path to search space"),
        ("--best_selected_space_path", str, './best_selected_space.json',
         "experiment level path to best selected space"),
        ("--batch_size", int, 128, "batch size"),
        ("--opt", str, "SGD", "optimizer"),
        ("--epochs", int, 2, "epoch limit"),
        ("--lr", float, 0.001, "learning rate"),
    )
    parser = argparse.ArgumentParser("network_morphism")
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # Prepare the output locations before logging and training start.
    mkdirs(
        args.experiment_dir,
        args.result_path,
        args.log_path,
        args.search_space_path,
        args.best_selected_space_path,
    )
    trial_directory = os.path.join(args.experiment_dir, 'train', str(args.trial_id))
    create_dir(trial_directory)
    init_logger(args.log_path)

    Train(args).run_trial_job()
class NetworkMorphismTrainer:
    """Train and evaluate one network-morphism candidate on CIFAR-10.

    The constructor builds the CIFAR-10 loaders, turns the JSON graph
    description into a PyTorch model and prepares the loss, the optimizer and
    a cosine learning-rate schedule.  ``train`` runs with early stopping and
    returns an accuracy estimate; ``retrain`` runs every epoch and saves a
    checkpoint each time the test accuracy improves.
    """

    # Supported ``args.opt`` values -> (optimizer class, extra keyword arguments).
    _OPTIMIZERS = {
        "SGD": (optim.SGD, {"momentum": 0.9, "weight_decay": 3e-4}),
        "Adadelta": (optim.Adadelta, {}),
        "Adagrad": (optim.Adagrad, {}),
        "Adam": (optim.Adam, {}),
        "Adamax": (optim.Adamax, {}),
        "RMSprop": (optim.RMSprop, {}),
    }

    @classmethod
    def _optimizer_spec(cls, name):
        """Return ``(optimizer_class, extra_kwargs)`` for ``name``.

        Raises
        ------
        ValueError
            If ``name`` is not one of the supported optimizers.
        """
        try:
            return cls._OPTIMIZERS[name]
        except KeyError:
            raise ValueError(
                "Unsupported optimizer '{}'; expected one of: {}".format(
                    name, ", ".join(sorted(cls._OPTIMIZERS)))
            ) from None

    def __init__(self, model_json, args):
        """Build loaders, model, loss, optimizer and scheduler.

        Parameters
        ----------
        model_json
            JSON representation of the graph, passed to ``json_to_graph``.
        args
            Parsed command-line namespace; reads ``batch_size``, ``epochs``,
            ``lr``, ``opt``, ``data_dir`` and ``trial_id`` here, and
            ``result_path`` / ``best_checkpoint_dir`` later on.

        Raises
        ------
        ValueError
            If ``args.opt`` names an unsupported optimizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = args.batch_size
        self.epochs = args.epochs
        self.lr = args.lr
        self.optimizer_name = args.opt
        self.data_dir = args.data_dir
        self.trial_id = args.trial_id
        self.args = args

        # Fail fast on a bad optimizer name, before any download or model build.
        optimizer_cls, optimizer_kwargs = self._optimizer_spec(self.optimizer_name)

        # Loading Data
        logger.info("Preparing data..")

        transform_train, transform_test = datasets.data_transforms_cifar10()

        trainset = torchvision.datasets.CIFAR10(
            root=self.data_dir, train=True, download=True, transform=transform_train
        )
        self.trainloader = data.DataLoader(
            trainset, batch_size=self.batch_size, shuffle=True, num_workers=1
        )

        testset = torchvision.datasets.CIFAR10(
            root=self.data_dir, train=False, download=True, transform=transform_test
        )
        self.testloader = data.DataLoader(
            testset, batch_size=self.batch_size, shuffle=False, num_workers=1
        )

        # Model
        logger.info("Building model..")
        # build model from json representation
        self.graph = json_to_graph(model_json)

        self.net = self.graph.produce_torch_model()

        if self.device == "cuda" and torch.cuda.device_count() > 1:
            self.net = nn.DataParallel(self.net)
        self.net.to(self.device)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optimizer_cls(
            self.net.parameters(), lr=self.lr, **optimizer_kwargs
        )

        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.epochs)

    def train_one_epoch(self):
        """
        train model on each epoch in trainset
        """
        self.net.train()

        for inputs, targets in self.trainloader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.net(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

    def validate_one_epoch(self, epoch):
        """ eval model on each epoch in testset

        Appends an accuracy record for ``epoch`` to ``args.result_path`` and
        returns ``(test_loss, acc)``, where ``test_loss`` is the sum of the
        per-batch losses and ``acc`` is the fraction of correct predictions.
        """
        self.net.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in self.testloader:
                inputs, targets = inputs.to(
                    self.device), targets.to(self.device)
                outputs = self.net(inputs)
                loss = self.criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        # An empty test loader yields accuracy 0 instead of a ZeroDivisionError.
        acc = correct / total if total else 0.0
        logger.info("Epoch: %d, accuracy: %.3f", epoch, acc)
        result = {"type": "Accuracy", "result": {
            "sequence": epoch, "category": "epoch", "value": acc}}

        save_json_result(self.args.result_path, result)
        return test_loss, acc

    def train(self):
        """Train with early stopping and return the estimated accuracy.

        The estimate is the mean test accuracy over the most recent epochs,
        at most ``Constant.MAX_NO_IMPROVEMENT_NUM`` of them; it is ``0.0``
        when no epoch was run.  On a CUDA out-of-memory ``RuntimeError`` the
        model-size limit in ``Constant`` is lowered below the current graph
        size and ``None`` is returned.  Any other exception is logged and
        re-raised.
        """
        try:
            max_no_improvement_num = Constant.MAX_NO_IMPROVEMENT_NUM
            early_stop = EarlyStop(max_no_improvement_num)
            early_stop.on_train_begin()
            test_metric_value_list = []

            for ep in range(self.epochs):
                self.train_one_epoch()
                test_loss, test_acc = self.validate_one_epoch(ep)
                self.scheduler.step()
                test_metric_value_list.append(test_acc)
                decreasing = early_stop.on_epoch_end(test_loss)
                if not decreasing:
                    break

            # Average over the epochs that were actually recorded, so the
            # divisor can never be zero or exceed the number of samples.
            window = min(max_no_improvement_num, len(test_metric_value_list))
            recent = test_metric_value_list[-window:] if window > 0 else test_metric_value_list
            estimated_performance = sum(recent) / len(recent) if recent else 0.0

            logger.info("final accuracy: %.3f", estimated_performance)

        except RuntimeError as e:
            if not re.search('out of memory', str(e)):
                raise
            print(
                '\nCurrent model size is too big. Discontinuing training this model to search for other models.')
            Constant.MAX_MODEL_SIZE = self.graph.size()-1
            return None
        except Exception as e:
            logger.exception(e)
            raise

        return estimated_performance

    def retrain(self):
        """Train for every epoch, checkpointing whenever test accuracy improves."""
        logger.info("here")
        try:
            best_acc = 0.0
            for ep in range(self.epochs):
                logger.info(ep)
                self.train_one_epoch()
                _, test_acc = self.validate_one_epoch(ep)
                self.scheduler.step()
                if test_acc > best_acc:
                    best_acc = test_acc
                    # NOTE(review): the total epoch count, not ``ep``, is passed
                    # as the epoch argument — presumably so one checkpoint name
                    # is reused; confirm against save_best_checkpoint.
                    save_best_checkpoint(self.args.best_checkpoint_dir,
                                         self.net, self.optimizer, self.epochs)

            logger.info("final accuracy: %.3f", best_acc)

        except Exception as exception:
            logger.exception(exception)
            raise
class EarlyStop:
    """Track epoch losses and decide when training should stop.

    Attributes:
        training_losses: Every loss passed to ``on_epoch_end`` since the last reset.
        minimum_loss: The lowest loss accepted as an improvement so far.
        no_improvement_count: Number of consecutive epochs without improvement.
        _max_no_improvement_num: How many non-improving epochs are tolerated.
        _done: Whether the tolerated number has been exceeded.
        _min_loss_dec: Minimum decrease for a loss to count as an improvement.
    """

    def __init__(self, max_no_improvement_num=None, min_loss_dec=None):
        self.training_losses = []
        # Start at +inf (same value ``on_train_begin`` sets) so that
        # ``on_epoch_end`` also works when the reset call was skipped.
        self.minimum_loss = float('inf')
        self.no_improvement_count = 0
        if max_no_improvement_num is None:
            max_no_improvement_num = Constant.MAX_NO_IMPROVEMENT_NUM
        self._max_no_improvement_num = max_no_improvement_num
        self._done = False
        if min_loss_dec is None:
            min_loss_dec = Constant.MIN_LOSS_DEC
        self._min_loss_dec = min_loss_dec

    def on_train_begin(self):
        """Reset the early stop state.
        Call every time a training run begins.
        """
        self.training_losses = []
        self.no_improvement_count = 0
        self._done = False
        self.minimum_loss = float('inf')

    def on_epoch_end(self, loss):
        """Record the loss of an epoch and check the early stop condition.
        Call every time a training epoch ends.
        Args:
            loss: The loss achieved by the epoch.
        Returns:
            True if training should continue; False once the no-improvement
            limit was exceeded earlier and this epoch did not improve either.
        """
        self.training_losses.append(loss)
        improved = not loss > (self.minimum_loss - self._min_loss_dec)

        if self._done and not improved:
            return False

        if improved:
            self.no_improvement_count = 0
            self.minimum_loss = loss
        else:
            self.no_improvement_count += 1

        # The stop flag takes effect on the next non-improving epoch.
        if self.no_improvement_count > self._max_no_improvement_num:
            self._done = True

        return True


def save_json_result(path, data):
    """Append ``data`` to ``path`` as one JSON document followed by a newline."""
    with open(path, 'a') as f:
        json.dump(data, f)
        f.write('\n')


def random_channel_shuffle(x):
    """Return ``x`` with dimension 1 reordered by a random permutation."""
    num_channels = x.data.size()[1]
    indices = torch.randperm(num_channels)
    return x[:, indices]


def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel dimension is viewed as ``(groups, channels_per_group)``, the
    two axes are swapped, and the result is flattened back, so output channels
    are taken from the groups in round-robin order.
    """
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape -> swap group/channel axes -> flatten
    grouped = x.view(batchsize, groups, channels_per_group, height, width)
    swapped = torch.transpose(grouped, 1, 2).contiguous()
    return swapped.view(batchsize, -1, height, width)
class AuxiliaryHead(nn.Module):
    """ Auxiliary head in 2/3 place of network to let the gradient flow well """

    def __init__(self, input_size, C, n_classes):
        """ assuming input size 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        pool_stride = input_size - 5
        layers = [
            nn.ReLU(inplace=True),
            # 5x5 average pooling: 2x2 out for both supported input sizes
            nn.AvgPool2d(5, stride=pool_stride, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 2x2 convolution over the 2x2 map: 1x1 out
            nn.Conv2d(128, 768, kernel_size=2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        """Return class logits for feature map ``x``."""
        features = self.net(x)
        flat = features.view(features.size(0), -1)  # flatten
        return self.linear(flat)


class Node(nn.Module):
    """An intermediate cell node: one LayerChoice per predecessor plus an InputChoice.

    In search mode only the first ``channels // k`` channels of each
    predecessor go through the candidate op; the rest bypass it and the
    concatenated result is channel-shuffled.
    """

    def __init__(self, node_id, num_prev_nodes, channels, k, num_downsample_connect, search):
        super().__init__()
        if search:
            self.k = k
            op_channels = channels // k
        else:
            op_channels = channels

        self.search = search
        self.ops = nn.ModuleList()
        choice_keys = []
        for prev_idx in range(num_prev_nodes):
            # the first ``num_downsample_connect`` predecessors are downsampled
            stride = 2 if prev_idx < num_downsample_connect else 1
            edge_key = "{}_p{}".format(node_id, prev_idx)
            choice_keys.append(edge_key)
            self.ops.append(
                mutables.LayerChoice(self._candidate_ops(op_channels, stride), key=edge_key))
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(
            choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
        self.pool = nn.MaxPool2d(2, 2)

    @staticmethod
    def _candidate_ops(op_channels, stride):
        """Build the ordered candidate operations for one edge."""
        candidates = OrderedDict()
        candidates["maxpool"] = ops.PoolBN('max', op_channels, 3, stride, 1, affine=False)
        candidates["avgpool"] = ops.PoolBN('avg', op_channels, 3, stride, 1, affine=False)
        if stride == 1:
            candidates["skipconnect"] = nn.Identity()
        else:
            candidates["skipconnect"] = ops.FactorizedReduce(op_channels, op_channels, affine=False)
        candidates["sepconv3x3"] = ops.SepConv(op_channels, op_channels, 3, stride, 1, affine=False)
        candidates["sepconv5x5"] = ops.SepConv(op_channels, op_channels, 5, stride, 2, affine=False)
        candidates["dilconv3x3"] = ops.DilConv(op_channels, op_channels, 3, stride, 2, 2, affine=False)
        candidates["dilconv5x5"] = ops.DilConv(op_channels, op_channels, 5, stride, 4, 2, affine=False)
        return candidates

    def _partial_channel_forward(self, op, x):
        """Apply ``op`` to the first ``1/k`` of the channels of ``x`` and shuffle."""
        split = x.shape[1] // self.k
        sampled = x[:, :split, :, :]
        bypassed = x[:, split:, :, :]
        out = op(sampled)
        # reduction edge: the op changed the spatial size, so pool the bypassed part too
        if out.shape[2] != x.shape[2]:
            bypassed = self.pool(bypassed)
        merged = torch.cat([out, bypassed], dim=1)
        return channel_shuffle(merged, self.k)

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes), "len(self.ops) != len(prev_nodes) in Node"
        # for each candidate predecessor of this intermediate node
        if self.search:
            # in search: partial channel connection followed by channel shuffle
            results = [self._partial_channel_forward(op, x)
                       for op, x in zip(self.ops, prev_nodes)]
        else:
            # in retrain, no channel shuffle
            results = [op(node) for op, node in zip(self.ops, prev_nodes)]

        output = []
        for edge_out in results:
            output.append(self.drop_path(edge_out) if edge_out is not None else None)
        return self.input_switch(output)
class Cell(nn.Module):
    """One cell: two preprocessed inputs followed by ``n_nodes`` searchable nodes."""

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, k, search):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            first_adapter = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            first_adapter = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc0 = first_adapter
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag: node at depth d sees d predecessors
        cell_kind = "reduce" if reduction else "normal"
        num_downsample_connect = 2 if reduction else 0
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            node_id = "{}_n{}".format(cell_kind, depth)
            self.mutable_ops.append(
                Node(node_id, depth, channels, k, num_downsample_connect, search))

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        states = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            node_out = node(states)
            states.append(node_out)

        # the cell output concatenates the intermediate nodes only
        return torch.cat(states[2:], dim=1)


class CNN(nn.Module):
    """Stem, a stack of cells with two reduction points, and a linear classifier."""

    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, k=4, n_nodes=4, stem_multiplier=3, auxiliary=False, search=True):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        # -1 never matches a layer index, which disables the auxiliary head
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        stem_channels = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, stem_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(stem_channels)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
        channels_pp = channels_p = stem_channels
        c_cur = channels

        reduction_layers = [n_layers // 3, 2 * n_layers // 3]
        self.cells = nn.ModuleList()
        reduction = False
        for layer_idx in range(n_layers):
            reduction_p = reduction
            # Reduce featuremap size and double channels in 1/3 and 2/3 layer.
            reduction = layer_idx in reduction_layers
            if reduction:
                c_cur *= 2

            self.cells.append(
                Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, k, search))
            channels_pp, channels_p = channels_p, c_cur * n_nodes

            if layer_idx == self.aux_pos:
                self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        """Return ``logits``, or ``(logits, aux_logits)`` when the auxiliary head ran."""
        s0 = s1 = self.stem(x)

        aux_logits = None
        for layer_idx, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            # the auxiliary head only contributes while training
            if layer_idx == self.aux_pos and self.training:
                aux_logits = self.aux_head(s1)

        pooled = self.gap(s1)
        flat = pooled.view(pooled.size(0), -1)  # flatten
        logits = self.linear(flat)

        if aux_logits is None:
            return logits
        return logits, aux_logits

    def drop_path_prob(self, p):
        """Set the drop probability ``p`` on every ``ops.DropPath`` submodule."""
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p

    def _loss(self, input, target):
        # NOTE(review): ``_criterion`` is not assigned anywhere in this class;
        # presumably the caller is expected to attach it — confirm before use.
        logits = self(input)
        return self._criterion(logits, target)
class PCdartsRetrainer(Retrainer):
    """Epoch-level train / validate loops for retraining a fixed PC-DARTS model.

    Parameters
    ----------
    aux_weight : float
        Weight of the auxiliary-head loss; the auxiliary loss is skipped when
        this is not positive or the model returns no auxiliary logits.
    grad_clip : float
        Maximum gradient norm passed to ``clip_grad_norm_``.
    epochs : int
        Total number of epochs, used only in log messages.
    log_frequency : int
        Log progress every this many steps.
    """

    def __init__(self, aux_weight, grad_clip, epochs, log_frequency):
        # NOTE(review): the base-class initializer is not called here;
        # confirm that ``Retrainer`` does not require it.
        self.aux_weight = aux_weight
        self.grad_clip = grad_clip
        self.epochs = epochs
        self.log_frequency = log_frequency

    @staticmethod
    def _split_outputs(outputs):
        """Return ``(logits, aux_logits)``; ``aux_logits`` is None for a bare tensor."""
        if isinstance(outputs, (tuple, list)):
            logits, aux_logits = outputs
            return logits, aux_logits
        return outputs, None

    def train(self, train_loader, model, optimizer, criterion, epoch):
        """Run one training epoch over ``train_loader`` and log running metrics."""
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        cur_lr = optimizer.param_groups[0]["lr"]
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)

        model.train()

        last_step = len(train_loader) - 1
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            bs = x.size(0)

            optimizer.zero_grad()
            # The model may return only logits (no auxiliary head), so do not
            # blindly unpack its output into two values.
            logits, aux_logits = self._split_outputs(model(x))
            loss = criterion(logits, y)
            if self.aux_weight > 0. and aux_logits is not None:
                loss = loss + self.aux_weight * criterion(aux_logits, y)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            accuracy = utils.accuracy(logits, y, topk=(1, 5))
            losses.update(loss.item(), bs)
            top1.update(accuracy["acc1"], bs)
            top5.update(accuracy["acc5"], bs)

            if step % self.log_frequency == 0 or step == last_step:
                logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, self.epochs, step, last_step, losses=losses,
                        top1=top1, top5=top5))

        logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

    def validate(self, valid_loader, model, criterion, epoch, cur_step):
        """Evaluate ``model`` on ``valid_loader`` and return the average top-1 accuracy.

        ``cur_step`` is accepted for interface compatibility and is not used.
        """
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        model.eval()

        last_step = len(valid_loader) - 1
        with torch.no_grad():
            for step, (X, y) in enumerate(valid_loader):
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
                bs = X.size(0)

                logits, _ = self._split_outputs(model(X))
                loss = criterion(logits, y)

                accuracy = utils.accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), bs)
                top1.update(accuracy["acc1"], bs)
                top5.update(accuracy["acc5"], bs)

                if step % self.log_frequency == 0 or step == last_step:
                    logger.info(
                        "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, self.epochs, step, last_step, losses=losses,
                            top1=top1, top5=top5))

        logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

        return top1.avg
if __name__ == "__main__":
    parser = ArgumentParser("PCDARTS retrain")
    parser.add_argument("--data_dir", type=str,
                        default='./', help="path to the dataset")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="training result")
    parser.add_argument("--log_path", type=str,
                        default='.0/log', help="log for info")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument("--best_checkpoint_dir", type=str,
                        default='', help="default name is best_checkpoint_epoch{}.pth")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial_id,start from 0')
    parser.add_argument("--layers", default=20, type=int)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--batch_size", default=96, type=int)
    parser.add_argument("--log_frequency", default=10, type=int)
    parser.add_argument("--epochs", default=600, type=int)
    parser.add_argument("--aux_weight", default=0.4, type=float)
    parser.add_argument("--drop_path_prob", default=0.2, type=float)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--class_num", default=10, type=int, help="cifar10")
    parser.add_argument("--channels", default=36, type=int)
    parser.add_argument("--grad_clip", default=6., type=float)
    args = parser.parse_args()

    mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir)
    init_logger(args.log_path)
    logger.info(args)
    # the trial id doubles as the random seed
    set_seed(args.trial_id)
    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir)

    model = CNN(32, 3, args.channels, args.class_num, args.layers, auxiliary=True, search=False)
    apply_fixed_architecture(model, args.best_selected_space_path)
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    criterion.to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)

    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    retrainer = PCdartsRetrainer(aux_weight=args.aux_weight,
                                 grad_clip=args.grad_clip,
                                 epochs=args.epochs,
                                 log_frequency=args.log_frequency)
    best_top1 = 0.
    start_time = time.time()
    for epoch in range(args.epochs):
        # drop-path probability grows linearly with the epoch index
        drop_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob(drop_prob)

        # training
        retrainer.train(train_loader, model, optimizer, criterion, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step)
        # The backend filters the terminal output for records shaped like
        # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
        accuracy_record = {"type": "Accuracy",
                           "result": {"sequence": epoch, "category": "epoch", "value": top1}}
        logger.info(accuracy_record)
        with open(args.result_path, "a") as file:
            file.write(str(accuracy_record) + '\n')
        best_top1 = max(best_top1, top1)

        lr_scheduler.step()

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    cost_time = time.time() - start_time
    # The backend filters the terminal output for records shaped like
    # {"type": "Cost_time", "result": {"value": "* s"}}
    cost_record = {"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}
    logger.info(cost_record)
    with open(args.result_path, "a") as file:
        file.write(str(cost_record))

    if args.epochs > 0:
        # NOTE(review): this stores the weights of the final epoch, which are
        # not necessarily those that reached ``best_top1`` — confirm intent.
        save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch)
        logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))))
    else:
        # With no epoch run the loop variable does not exist; skip the checkpoint.
        logger.warning("No epoch was run; skipping checkpoint.")
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    parser = ArgumentParser("PCDARTS train")
    parser.add_argument("--data_dir", type=str,
                        default='../data/', help="path to the dataset")
    parser.add_argument("--result_path", type=str,
                        default='.0/result.json', help="training result")
    parser.add_argument("--log_path", type=str,
                        default='.0/log', help="log for info")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search space of PCDARTS")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial_id,start from 0')
    parser.add_argument('--model_lr', type=float, default=0.1, help='learning rate for training model weights')
    parser.add_argument('--arch_lr', type=float, default=6e-4, help='learning rate for training architecture')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--batch_size", default=96, type=int)
    parser.add_argument("--log_frequency", default=50, type=int)
    parser.add_argument("--class_num", default=10, type=int, help="cifar10")
    parser.add_argument("--epochs", default=5, type=int)
    parser.add_argument("--pre_epochs", default=15, type=int, help='pre epochs to train weight only')
    parser.add_argument("--k", default=4, type=int, help="channel portion of channel shuffle")
    parser.add_argument("--unrolled", default=False, action="store_true")
    args = parser.parse_args()

    mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path)
    init_logger(args.log_path, "info")
    logger.info(args)
    # the trial id doubles as the random seed
    set_seed(args.trial_id)

    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", root=args.data_dir)

    model = CNN(32, 3, args.channels, args.class_num, args.layers, n_nodes=args.nodes, k=args.k)
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), args.model_lr, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    logger.info("initializing trainer")
    trainer_callbacks = [
        LRSchedulerCallback(lr_scheduler),
        BestArchitectureCheckpoint(args.best_selected_space_path, args.epochs),
    ]
    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           optimizer=optim,
                           num_epochs=args.epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           mutator=PCdartsMutator(model),
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           arch_lr=args.arch_lr,
                           unrolled=args.unrolled,
                           result_path=args.result_path,
                           num_pre_epochs=args.pre_epochs,
                           search_space_path=args.search_space_path,
                           callbacks=trainer_callbacks)

    logger.info("training")
    t1 = time.time()
    trainer.train()
    cost_time = time.time() - t1
    # The backend filters the terminal output for records shaped like
    # {"type": "Cost_time", "result": {"value": "* s"}}
    cost_record = {"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}
    logger.info(cost_record)
    with open(args.result_path, "a") as file:
        file.write(str(cost_record))
+ + It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the + value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will + be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully. + + Attributes + ---------- + choices: ParameterDict + dict that maps keys of LayerChoices to weighted-connection float tensors. + """ + def __init__(self, model): + super().__init__(model) + self.choices = nn.ParameterDict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1)) + if isinstance(mutable, InputChoice): + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.n_candidates)) + + def device(self): + for v in self.choices.values(): + return v.device + + def sample_search(self): + result = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1] + elif isinstance(mutable, InputChoice): + result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1) + return result + + def sample_final(self): + result = dict() + edges_max = dict() + choices = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + # multiply the normalized coefficients together to select top-1 op in each LayerChoice + predecessor_idx = int(mutable.key[-1]) + inputchoice_key = mutable.key[:-2] + "switch" + choices[mutable.key] = self.choices[mutable.key] * self.choices[inputchoice_key][predecessor_idx] + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + # select non-none top-1 op + max_val, index = torch.max(F.softmax(choices[mutable.key], dim=-1)[:-1], 0) + edges_max[mutable.key] = max_val + result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool() + for 
mutable in self.mutables: + if isinstance(mutable, InputChoice): + if mutable.n_chosen is not None: + weights = [] + for src_key in mutable.choose_from: + if src_key not in edges_max: + _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key) + weights.append(edges_max.get(src_key, 0.)) + weights = torch.tensor(weights) # pylint: disable=not-callable + # select top-2 strongest predecessor + _, topk_edge_indices = torch.topk(weights, mutable.n_chosen) + selected_multihot = [] + for i, src_key in enumerate(mutable.choose_from): + if i not in topk_edge_indices and src_key in result: + # If an edge is never selected, there is no need to calculate any op on this edge. + # This is to eliminate redundant calculation. + result[src_key] = torch.zeros_like(result[src_key]) + selected_multihot.append(i in topk_edge_indices) + result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable + else: + result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable + return result + + def _generate_search_space(self): + """ + Generate search space from mutables. 
+ Here is the search space format: + :: + { key_name: {"_type": "layer_choice", + "_value": ["conv1", "conv2"]} } + { key_name: {"_type": "input_choice", + "_value": {"candidates": ["in1", "in2"], + "n_chosen": 1}} } + Returns + ------- + dict + the generated search space + """ + res = OrderedDict() + res["op_list"] = OrderedDict() + res["search_space"] = {"reduction_cell": OrderedDict(), "normal_cell": OrderedDict()} + keys = [] + for mutable in self.mutables: + # for now we only generate flattened search space + if (len(res["search_space"]["reduction_cell"]) + len(res["search_space"]["normal_cell"])) >= 36: + break + + if isinstance(mutable, LayerChoice): + key = mutable.key + if key not in keys: + val = mutable.names + if not res["op_list"]: + res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]} + node_type = "normal_cell" if "normal" in key else "reduction_cell" + res["search_space"][node_type][key] = "op_list" + keys.append(key) + + elif isinstance(mutable, InputChoice): + key = mutable.key + if key not in keys: + node_type = "normal_cell" if "normal" in key else "reduction_cell" + res["search_space"][node_type][key] = {"_type": "input_choice", + "_value": {"candidates": mutable.choose_from, + "n_chosen": mutable.n_chosen}} + keys.append(key) + else: + raise TypeError("Unsupported mutable type: '%s'." 
% type(mutable)) + + return res + diff --git a/dubhe-tadl/pcdarts/reademe.md b/dubhe-tadl/pcdarts/reademe.md new file mode 100644 index 0000000..c58b3d1 --- /dev/null +++ b/dubhe-tadl/pcdarts/reademe.md @@ -0,0 +1,81 @@ +# train stage +`python pcdarts_train.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --search_space_path 'experiment_id/search_space.json' --best_selected_space_path 'experiment_id/best_selected_space.json' --trial_id 0 --layers 5 --model_lr 0.025 --arch_lr 3e-4 --epochs 2 --pre_epochs 1 --batch_size 64 --channels 16` + +# select stage +`python pcdarts_select.py --best_selected_space_path 'experiment_id/best_selected_space.json' ` + +# retrain stage +`python pcdarts_retrain.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --best_selected_space_path 'experiment_id/best_selected_space.json' --best_checkpoint_dir 'experiment_id/' --trial_id 0 --batch_size 96 --epochs 2 --lr 0.01 --layers 20 --channels 36` + +# output file +`result.json` +``` +{'type': 'Accuracy', 'result': {'sequence': 0, 'category': 'epoch', 'value': 0.1}} +{'type': 'Accuracy', 'result': {'sequence': 1, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 2, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 3, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 4, 'category': 'epoch', 'value': 0.0}} +{'type': 'Cost_time', 'result': {'value': '41.614346981048584 s'}} +``` + +`search_space.json` +``` +{ + "op_list": { + "_type": "layer_choice", + "_value": [ + "maxpool", + "avgpool", + "skipconnect", + "sepconv3x3", + "sepconv5x5", + "dilconv3x3", + "dilconv5x5", + "none" + ] + }, + "search_space": { + "normal_n2_p0": "op_list", + "normal_n2_p1": "op_list", + "normal_n2_switch": { + "_type": "input_choice", + "_value": { + "candidates": [ + "normal_n2_p0", + "normal_n2_p1" + ], + "n_chosen": 2 + } + }, + + ... 
+ } +``` + +`best_selected_space.json` +``` +{ + "normal_n2_p0": "dilconv5x5", + "normal_n2_p1": "dilconv5x5", + "normal_n2_switch": [ + "normal_n2_p0", + "normal_n2_p1" + ], + "normal_n3_p0": "sepconv3x3", + "normal_n3_p1": "dilconv5x5", + "normal_n3_p2": [], + "normal_n3_switch": [ + "normal_n3_p0", + "normal_n3_p1" + ], + "normal_n4_p0": [], + "normal_n4_p1": "dilconv5x5", + "normal_n4_p2": "sepconv5x5", + "normal_n4_p3": [], + "normal_n4_switch": [ + "normal_n4_p1", + "normal_n4_p2" + ], + ... +} +``` \ No newline at end of file diff --git a/dubhe-tadl/pdarts/__init__.py b/dubhe-tadl/pdarts/__init__.py new file mode 100644 index 0000000..f6ce939 --- /dev/null +++ b/dubhe-tadl/pdarts/__init__.py @@ -0,0 +1,2 @@ +from pytorch.pdarts.pdartsmutator import PdartsMutator +from pytorch.pdarts.pdartstrainer import PdartsTrainer \ No newline at end of file diff --git a/dubhe-tadl/pdarts/model.py b/dubhe-tadl/pdarts/model.py new file mode 100644 index 0000000..46a726c --- /dev/null +++ b/dubhe-tadl/pdarts/model.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
class AuxiliaryHead(nn.Module):
    """ Auxiliary head in 2/3 place of network to let the gradient flow well """

    def __init__(self, input_size, C, n_classes):
        """
        Parameters
        ----------
        input_size : int
            Spatial size of the incoming feature map; must be 7 or 8.
        C : int
            Number of input channels.
        n_classes : int
            Number of output classes.
        """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        """Return class logits for the feature map ``x``."""
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        return self.linear(out)


class Node(nn.Module):
    """
    One intermediate node of a cell.

    Holds one LayerChoice per predecessor (keys ``"<node_id>_p<i>"``) and an
    InputChoice (key ``"<node_id>_switch"``) that picks two of the results.
    """

    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, search, dropout_rate):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            # The first num_downsample_connect predecessors are down-sampled.
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            skip_op = nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)
            # In search, op-level dropout for skip-connect
            if search and self.dropout_rate > 0:
                skip_op = nn.Sequential(skip_op, nn.Dropout(self.dropout_rate))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
                    ("skipconnect", skip_op),
                    ("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))
        # In retrain, DropPath for non skip-connect, p in DropPath default to 0
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

    def forward(self, prev_nodes):
        """Apply one op per predecessor and combine the results through the input switch."""
        assert len(self.ops) == len(prev_nodes)
        output = []
        for op, node in zip(self.ops, prev_nodes):
            out = op(node)
            # An op may yield None (edge not chosen); pass None through untouched.
            if out is not None and not isinstance(op, nn.Identity):
                out = self.drop_path(out)
            output.append(out)
        return self.input_switch(output)


class Cell(nn.Module):
    """A cell: two preprocessing ops followed by ``n_nodes`` searchable nodes."""

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, search, dropout_rate):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        cell_kind = "reduce" if reduction else "normal"
        num_downsample = 2 if reduction else 0
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(
                Node("{}_n{}".format(cell_kind, depth), depth, channels, num_downsample, search, dropout_rate))

    def forward(self, s0, s1):
        """
        s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        Returns the channel-wise concatenation of all node outputs.
        """
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            tensors.append(node(tensors))

        return torch.cat(tensors[2:], dim=1)


class CNN(nn.Module):
    """
    Network made of a stem, ``n_layers`` stacked cells, and a linear classifier.

    Cells at 1/3 and 2/3 of the depth are reduction cells (channels doubled).
    With ``auxiliary=True`` an AuxiliaryHead is attached after the cell at
    2/3 of the depth and its logits are returned in training mode.
    """

    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, dropout_rate, n_nodes=4, stem_multiplier=3, auxiliary=False, search=True):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce featuremap size and double channels in 1/3 and 2/3 layer.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, search, dropout_rate)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

            if i == self.aux_pos:
                self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        """
        Returns
        -------
        Tensor or (Tensor, Tensor)
            ``logits`` alone, or ``(logits, aux_logits)`` when the auxiliary
            head is enabled and the module is in training mode.
        """
        s0 = s1 = self.stem(x)

        aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            if i == self.aux_pos and self.training:
                aux_logits = self.aux_head(s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)

        if aux_logits is not None:
            return logits, aux_logits
        return logits

    def drop_path_prob(self, p, search=True):
        """
        Update the drop probability used inside the network.

        Parameters
        ----------
        p : float
            New probability.
        search : bool
            If true, update the ``nn.Dropout`` that follows an identity
            skip-connect (search phase); otherwise update every
            ``ops.DropPath`` (retrain phase).
        """
        if search:
            for module in self.modules():
                # In search, update dropout rate.  nn.Dropout reads its
                # probability from ``p``; writing any other attribute has no effect.
                # NOTE(review): wrappers whose first op is not nn.Identity
                # (stride-2 skip-connect) are left untouched, as before.
                if (isinstance(module, nn.Sequential) and len(module) > 1
                        and isinstance(module[0], nn.Identity)
                        and isinstance(module[1], nn.Dropout)):
                    module[1].p = p
        else:
            # In retrain, update ops.DropPath
            for module in self.modules():
                if isinstance(module, ops.DropPath):
                    module.p = p
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PdartsRetrainer(Retrainer):
    """
    Runs the per-epoch train and validation loops for a fixed P-DARTS architecture.

    Parameters
    ----------
    aux_weight : float
        Weight of the auxiliary-head loss; the auxiliary loss is skipped when 0.
    grad_clip : float
        Max gradient norm passed to ``clip_grad_norm_``.
    epochs : int
        Total number of epochs (used for log messages only).
    log_frequency : int
        Log every ``log_frequency`` steps.
    """

    def __init__(self, aux_weight, grad_clip, epochs, log_frequency):
        self.aux_weight = aux_weight
        self.grad_clip = grad_clip
        self.epochs = epochs
        self.log_frequency = log_frequency

    def train(self, train_loader, model, optimizer, criterion, epoch):
        """Train ``model`` for one epoch over ``train_loader``."""
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        cur_lr = optimizer.param_groups[0]["lr"]
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)

        model.train()
        last_step = len(train_loader) - 1

        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            bs = x.size(0)

            optimizer.zero_grad()
            # The model returns (logits, aux_logits) only when its auxiliary
            # head is active; unpacking a bare tensor would silently split it
            # along the batch dimension.
            output = model(x)
            if isinstance(output, tuple):
                logits, aux_logits = output
            else:
                logits, aux_logits = output, None
            loss = criterion(logits, y)
            if self.aux_weight > 0. and aux_logits is not None:
                loss += self.aux_weight * criterion(aux_logits, y)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            accuracy = utils.accuracy(logits, y, topk=(1, 5))
            losses.update(loss.item(), bs)
            top1.update(accuracy["acc1"], bs)
            top5.update(accuracy["acc5"], bs)

            if step % self.log_frequency == 0 or step == last_step:
                logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, self.epochs, step, last_step, losses=losses,
                        top1=top1, top5=top5))

        logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

    def validate(self, valid_loader, model, criterion, epoch, cur_step):
        """
        Evaluate ``model`` on ``valid_loader``.

        ``cur_step`` is accepted for interface compatibility and is not used.

        Returns
        -------
        The average top-1 accuracy over the validation set.
        """
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        model.eval()
        last_step = len(valid_loader) - 1

        with torch.no_grad():
            for step, (X, y) in enumerate(valid_loader):
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
                bs = X.size(0)

                logits = model(X)
                loss = criterion(logits, y)

                accuracy = utils.accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), bs)
                top1.update(accuracy["acc1"], bs)
                top5.update(accuracy["acc5"], bs)

                if step % self.log_frequency == 0 or step == last_step:
                    logger.info(
                        "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, self.epochs, step, last_step, losses=losses,
                            top1=top1, top5=top5))

        logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

        return top1.avg


if __name__ == "__main__":
    parser = ArgumentParser("Pdarts retrain")
    parser.add_argument("--data_dir", type=str,
                        default='./', help="directory of the dataset")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="training result")
    parser.add_argument("--log_path", type=str,
                        default='.0/log', help="log for info")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument("--best_checkpoint_dir", type=str,
                        default='', help="default name is best_checkpoint_epoch{}.pth")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial_id,start from 0')
    parser.add_argument("--layers", default=20, type=int)
    parser.add_argument("--batch_size", default=96, type=int)
    parser.add_argument("--log_frequency", default=10, type=int)
    parser.add_argument("--epochs", default=600, type=int)
    parser.add_argument("--lr", default=0.025, type=float)
    parser.add_argument("--channels", default=36, type=int)
    parser.add_argument("--aux_weight", default=0.4, type=float)
    parser.add_argument("--drop_path_prob", default=0.3, type=float)
    # Without type=int a value given on the command line stays a string,
    # which DataLoader(num_workers=...) rejects.
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--grad_clip", default=5., type=float)
    args = parser.parse_args()

    mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir)
    init_logger(args.log_path)
    logger.info(args)
    set_seed(args.trial_id)
    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir)

    # Honour --channels (default 36) instead of a hard-coded width.
    model = CNN(32, 3, args.channels, 10, args.layers, auxiliary=True, search=False, dropout_rate=0.0)
    if isinstance(args.best_selected_space_path, str):
        with open(args.best_selected_space_path) as f:
            fixed_arc = json.load(f)
        apply_fixed_architecture(model, fixed_arc=fixed_arc["best_selected_space"])
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    criterion.to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)

    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)
    retrainer = PdartsRetrainer(aux_weight=args.aux_weight,
                                grad_clip=args.grad_clip,
                                epochs=args.epochs,
                                log_frequency=args.log_frequency)
    best_top1 = 0.
    start_time = time.time()
    for epoch in range(args.epochs):
        drop_prob = args.drop_path_prob * epoch / args.epochs
        # The model was built with search=False, so the DropPath modules are
        # the ones to update; the default (search=True) would leave them untouched.
        model.drop_path_prob(drop_prob, search=False)

        # training
        retrainer.train(train_loader, model, optimizer, criterion, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step)
        # The backend filters this record out of the terminal output, e.g.
        # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
        record = {"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}
        logger.info(record)
        with open(args.result_path, "a") as file:
            file.write(str(record) + '\n')
        best_top1 = max(best_top1, top1)

        lr_scheduler.step()

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    cost_time = time.time() - start_time
    # The backend filters this record out of the terminal output:
    # {"type": "Cost_time", "result": {"value": "* s"}}
    cost_record = {"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}
    logger.info(cost_record)
    with open(args.result_path, "a") as file:
        file.write(str(cost_record))

    save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch)
    logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))))
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    # Command-line entry point of the P-DARTS search (train) stage.
    parser = ArgumentParser("pdarts")
    parser.add_argument("--data_dir", type=str,
                        default='../data/', help="directory of the dataset")
    parser.add_argument("--result_path", type=str,
                        default='0/result.json', help="training result")
    parser.add_argument("--log_path", type=str,
                        default='0/log', help="log for info")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search space of PDARTS")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument('--trial_id', type=int, default=0, help='for ensuring reproducibility ')
    parser.add_argument('--model_lr', type=float, default=0.025, help='learning rate for training model weights')
    parser.add_argument('--arch_lr', type=float, default=3e-4, help='learning rate for training architecture')
    parser.add_argument("--epochs", default=2, type=int)
    parser.add_argument("--pre_epochs", default=15, type=int)
    parser.add_argument("--batch_size", default=96, type=int)
    parser.add_argument("--init_layers", default=5, type=int)
    # One value per P-DARTS stage for each of the three list arguments below.
    parser.add_argument('--add_layers', default=[0, 6, 12], nargs='+', type=int, help='add layers in each stage')
    parser.add_argument('--dropped_ops', default=[3, 2, 1], nargs='+', type=int, help='drop ops in each stage')
    parser.add_argument('--dropout_rates', default=[0.1, 0.4, 0.7], nargs='+', type=float, help='drop ops probability in each stage')
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--log_frequency", default=50, type=int)
    parser.add_argument("--class_num", default=10, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    args = parser.parse_args()

    mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path)
    init_logger(args.log_path, "info")
    # The trial id doubles as the random seed.
    set_seed(args.trial_id)
    logger.info(args)

    logger.info("initializing pdarts trainer")
    trainer = PdartsTrainer(
        init_layers=args.init_layers,
        pdarts_num_layers=args.add_layers,
        pdarts_num_to_drop=args.dropped_ops,
        pdarts_dropout_rates=args.dropout_rates,
        num_epochs=args.epochs,
        num_pre_epochs=args.pre_epochs,
        model_lr=args.model_lr,
        arch_lr=args.arch_lr,
        batch_size=args.batch_size,
        class_num=args.class_num,
        channels=args.channels,
        result_path=args.result_path,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        data_dir=args.data_dir,
        search_space_path=args.search_space_path,
        best_selected_space_path=args.best_selected_space_path
    )

    logger.info("training")
    start_time = time.time()
    trainer.train(validate=True)
    cost_time = time.time() - start_time
    # The backend filters this record out of the terminal output:
    # {"type": "Cost_time", "result": {"value": "* s"}}
    cost_record = {"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}
    logger.info(cost_record)
    with open(args.result_path, "a") as file:
        file.write(str(cost_record))
"result": {"value": "* s"}} + logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) + with open(args.result_path, "a") as file: + file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) diff --git a/dubhe-tadl/pdarts/pdartsmutator.py b/dubhe-tadl/pdarts/pdartsmutator.py new file mode 100644 index 0000000..a8df379 --- /dev/null +++ b/dubhe-tadl/pdarts/pdartsmutator.py @@ -0,0 +1,201 @@ +import copy + +import numpy as np +import torch +import logging +from collections import OrderedDict +from torch import nn + +from pytorch.darts.dartsmutator import DartsMutator +from pytorch.mutables import LayerChoice, InputChoice + +logger = logging.getLogger(__name__) + +class PdartsMutator(DartsMutator): + """ + It works with PdartsTrainer to calculate ops weights, + and drop weights in different PDARTS epochs. + """ + def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}): + self.pdarts_epoch_index = pdarts_epoch_index + self.pdarts_num_to_drop = pdarts_num_to_drop + # save the last two switches and choices for restrict skip + self.last_two_switches = None + self.last_two_choices = None + + if switches is None: + self.switches = {} + else: + self.switches = switches + + super(PdartsMutator, self).__init__(model) + + # this loop go through mutables with different keys, + # it's mainly to update length of choices. + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + switches = self.switches.get(mutable.key, [True for j in range(len(mutable))]) + # choices = self.choices[mutable.key] + + operations_count = np.sum(switches) + # +1 and -1 are caused by zero operation in darts network + # the zero operation is not in choices list(switches) in network, but its weight are in, + # so it needs one more weights and switch for zero. 
+ self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1)) + self.switches[mutable.key] = switches + + # update LayerChoice instances in model, + # it's physically remove dropped choices operations. + for module in self.model.modules(): + if isinstance(module, LayerChoice): + switches = self.switches.get(module.key) + choices = self.choices[module.key] + if len(module) > len(choices): + # from last to first, so that it won't effect previous indexes after removed one. + for index in range(len(switches)-1, -1, -1): + if switches[index] == False: + del module[index] + assert len(module) <= len(choices), "Failed to remove dropped choices." + + def export(self, last, switches): + # In last pdarts_epoches, need to restrict skipconnection + # Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length. + if last: + # restrict Up to 2 skipconnect (normal cell only) + name = "normal" + max_num = 2 + skip_num = self.check_skip_num(name, switches) + logger.info("Initially, the number of skipconnect is {}.".format(skip_num)) + while skip_num > max_num: + logger.info("Restricting {} skipconnect to {}.".format(skip_num, max_num)) + logger.info("Original normal_switch is {}.".format(switches)) + # update self.choices setting skip prob to 0 and self.switches setting skip prob to False + switches = self.delete_min_sk(name, switches) + logger.info("Restricted normal_switch is {}.".format(switches)) + skip_num = self.check_skip_num(name, switches) + + # from bool result convert to human readable by Mutator export() + results = super().sample_final() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + # As some operations are dropped physically, + # so it needs to fill back false to track dropped operations. 
+ trained_result = results[mutable.key] + trained_index = 0 + switches = self.switches[mutable.key] + result = torch.Tensor(switches).bool() + for index in range(len(result)): + if result[index]: + result[index] = trained_result[trained_index] + trained_index += 1 + results[mutable.key] = result + return results + + def drop_paths(self): + """ + This method is called when a PDARTS epoch is finished. + It prepares switches for next epoch. + candidate operations with False switch will be doppped in next epoch. + """ + all_switches = copy.deepcopy(self.switches) + for key in all_switches: + switches = all_switches[key] + idxs = [] + for j in range(len(switches)): + if switches[j]: + idxs.append(j) + sorted_weights = self.choices[key].data.cpu().numpy()[:-1] + drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]] + for idx in drop: + switches[idxs[idx]] = False + return all_switches + + + def check_skip_num(self, name, switches): + counter = 0 + for key in switches: + if name in key: + # zero operation not in switches, so "skipconnect" in 2 + if switches[key][2]: + counter += 1 + return counter + + def delete_min_sk(self, name, switches): + def _get_sk_idx(key, switches): + if not switches[key][2]: + idx = -1 + else: + idx = 0 + for i in range(2): + # switches has 1 True, self.switches has 2 True + if self.switches[key][i]: + idx += 1 + return idx + sk_choices = [1.0 for i in range(14)] + sk_keys = [None for i in range(14)] # key has skip connection + sk_choices_idx = -1 + for key in switches: + if name in key: + # default key in order + sk_choices_idx += 1 + idx = _get_sk_idx(key, switches) + if not idx == -1: + sk_keys[sk_choices_idx] = key + sk_choices[sk_choices_idx] = self.choices[key][idx] + min_sk_idx = np.argmin(sk_choices) + idx = _get_sk_idx(sk_keys[min_sk_idx], switches) + # modify self.choices or copy.deepcopy ? + self.choices[sk_keys[min_sk_idx]][idx] = 0.0 + # modify self.switches or copy.deepcopy ? 
+ # self.switches indicate last two switches, and switches indicate present(last) switches + self.switches[sk_keys[min_sk_idx]][2] = False + switches[sk_keys[min_sk_idx]][2] = False + return switches + + + def _generate_search_space(self): + """ + Generate search space from mutables. + Here is the search space format: + :: + { key_name: {"_type": "layer_choice", + "_value": ["conv1", "conv2"]} } + { key_name: {"_type": "input_choice", + "_value": {"candidates": ["in1", "in2"], + "n_chosen": 1}} } + Returns + ------- + dict + the generated search space + """ + res = OrderedDict() + res["op_list"] = OrderedDict() + res["search_space"] = {"reduction_cell": OrderedDict(), "normal_cell": OrderedDict()} + keys = [] + for mutable in self.mutables: + # for now we only generate flattened search space + if (len(res["search_space"]["reduction_cell"]) + len(res["search_space"]["normal_cell"])) >= 36: + break + + if isinstance(mutable, LayerChoice): + key = mutable.key + if key not in keys: + val = mutable.names + if not res["op_list"]: + res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]} + node_type = "normal_cell" if "normal" in key else "reduction_cell" + res["search_space"][node_type][key] = "op_list" + keys.append(key) + + elif isinstance(mutable, InputChoice): + key = mutable.key + if key not in keys: + node_type = "normal_cell" if "normal" in key else "reduction_cell" + res["search_space"][node_type][key] = {"_type": "input_choice", + "_value": {"candidates": mutable.choose_from, + "n_chosen": mutable.n_chosen}} + keys.append(key) + else: + raise TypeError("Unsupported mutable type: '%s'." 
% type(mutable)) + + return res \ No newline at end of file diff --git a/dubhe-tadl/pdarts/pdartstrainer.py b/dubhe-tadl/pdarts/pdartstrainer.py new file mode 100644 index 0000000..ad4a43b --- /dev/null +++ b/dubhe-tadl/pdarts/pdartstrainer.py @@ -0,0 +1,167 @@ +import os +import logging +import torch +import torch.nn as nn +import numpy as np +from collections import OrderedDict +import json +from pytorch.callbacks import LRSchedulerCallback +from pytorch.trainer import BaseTrainer, TorchTensorEncoder +from pytorch.utils import dump_global_result + +from model import CNN +from pdartsmutator import PdartsMutator +from pytorch.darts.utils import accuracy +from pytorch.darts import datasets +from pytorch.darts.dartstrainer import DartsTrainer + +logger = logging.getLogger(__name__) + +class PdartsTrainer(BaseTrainer): + """ + This trainer implements the PDARTS algorithm. + PDARTS bases on DARTS algorithm, and provides a network growth approach to find deeper and better network. + This class relies on pdarts_num_layers and pdarts_num_to_drop parameters to control how network grows. + pdarts_num_layers means how many layers more than first epoch. + pdarts_num_to_drop means how many candidate operations should be dropped in each epoch. + So that the grew network can in similar size. 
+ """ + + def __init__(self, init_layers, pdarts_num_layers, pdarts_num_to_drop, pdarts_dropout_rates, num_epochs, num_pre_epochs, model_lr, class_num, + arch_lr, channels, batch_size, result_path, log_frequency, unrolled, data_dir, search_space_path, + best_selected_space_path, device=None, workers=4): + super(PdartsTrainer, self).__init__() + self.init_layers = init_layers + self.class_num = class_num + self.channels = channels + self.model_lr = model_lr + self.num_epochs = num_epochs + self.class_num = class_num + self.pdarts_num_layers = pdarts_num_layers + self.pdarts_num_to_drop = pdarts_num_to_drop + self.pdarts_dropout_rates = pdarts_dropout_rates + self.pdarts_epoches = len(pdarts_num_to_drop) + self.search_space_path = search_space_path + self.best_selected_space_path = best_selected_space_path + + logger.info("loading data") + dataset_train, dataset_valid = datasets.get_dataset( + "cifar10", root=data_dir) + self.darts_parameters = { + "metrics": lambda output, target: accuracy(output, target, topk=(1,)), + "arch_lr": arch_lr, + "num_epochs": num_epochs, + "num_pre_epochs": num_pre_epochs, + "dataset_train": dataset_train, + "dataset_valid": dataset_valid, + "batch_size": batch_size, + "result_path": result_path, + "workers": workers, + "device": device, + "log_frequency": log_frequency, + "unrolled": unrolled, + "search_space_path": None + } + + def train(self, validate=False): + switches = None + last = False + for epoch in range(self.pdarts_epoches): + if epoch == self.pdarts_epoches - 1: + last = True + # create network for each stage + layers = self.init_layers + self.pdarts_num_layers[epoch] + init_dropout_rate = float(self.pdarts_dropout_rates[epoch]) + model = CNN(32, 3, self.channels, self.class_num, layers, + init_dropout_rate, n_nodes=4, search=True) + criterion = nn.CrossEntropyLoss() + optim = torch.optim.SGD( + model.parameters(), self.model_lr, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 
+ optim, self.num_epochs, eta_min=0.001) + + logger.info( + "############Start PDARTS training epoch %s############", epoch) + self.mutator = PdartsMutator( + model, epoch, self.pdarts_num_to_drop, switches) + if epoch == 0: + # only write original search space in first stage + search_space = self.mutator._generate_search_space() + dump_global_result(self.search_space_path, + search_space) + + darts_callbacks = [] + if lr_scheduler is not None: + darts_callbacks.append(LRSchedulerCallback(lr_scheduler)) + # darts_callbacks.append(ArchitectureCheckpoint( + # os.path.join(self.selected_space_path, "stage_{}".format(epoch)))) + self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, + optimizer=optim, callbacks=darts_callbacks, **self.darts_parameters) + + for train_epoch in range(self.darts_parameters["num_epochs"]): + for callback in darts_callbacks: + callback.on_epoch_begin(train_epoch) + + # training + logger.info("Epoch %d Training", train_epoch) + if train_epoch < self.darts_parameters["num_pre_epochs"]: + dropout_rate = init_dropout_rate * \ + (self.darts_parameters["num_epochs"] - train_epoch - + 1) / self.darts_parameters["num_epochs"] + else: + # scale_factor = 0.2 + dropout_rate = init_dropout_rate * \ + np.exp(-(epoch - + self.darts_parameters["num_pre_epochs"]) * 0.2) + + model.drop_path_prob(search=True, p=dropout_rate) + self.trainer.train_one_epoch(train_epoch) + + if validate: + # validation + logger.info("Epoch %d Validating", train_epoch + 1) + self.trainer.validate_one_epoch( + train_epoch, log_print=True if last else False) + + for callback in darts_callbacks: + callback.on_epoch_end(train_epoch) + + switches = self.mutator.drop_paths() + + # In last pdarts_epoches, need to restrict skipconnection and save best structure + if last: + res = OrderedDict() + op_value = [value for value in search_space["op_list"]["_value"] if value != 'none'] + res["op_list"] = search_space["op_list"] + res["op_list"]["_value"] = op_value + 
res["best_selected_space"] = self.mutator.export(last, switches) + logger.info(res) + dump_global_result(self.best_selected_space_path, res) + + def validate(self): + self.trainer.validate() + + def export(self, file, last, switches): + self.mutator.export(last, switches) + mutator_export = self.mutator.export() + + with open(file, "w") as f: + json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder) + + def checkpoint(self, file_path, epoch): + if isinstance(self.model, nn.DataParallel): + child_model_state_dict = self.model.module.state_dict() + else: + child_model_state_dict = self.model.state_dict() + + save_state = {'child_model_state_dict': child_model_state_dict, + 'optimizer_state_dict': self.optimizer.state_dict(), + 'epoch': epoch} + + dest_path = os.path.join( + file_path, "best_checkpoint_epoch_{}.pth.tar".format(epoch)) + logger.info("Saving model to %s", dest_path) + torch.save(save_state, dest_path) + raise NotImplementedError("Not implemented yet") + + diff --git a/dubhe-tadl/pdarts/reademe.md b/dubhe-tadl/pdarts/reademe.md new file mode 100644 index 0000000..a1376e3 --- /dev/null +++ b/dubhe-tadl/pdarts/reademe.md @@ -0,0 +1,92 @@ +# train stage +`python pdarts_train.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --search_space_path 'experiment_id/search_space.json' --best_selected_space_path 'experiment_id/best_selected_space.json' --trial_id 0 --model_lr 0.025 --arch_lr 3e-4 --epochs 2 --pre_epochs 1 --batch_size 64 --channels 16 --init_layers 5 --add_layer 0 6 12 --dropped_ops 3 2 1 --dropout_rates 0.1 0.4 0.7` + +# select stage +`python pdarts_select.py --best_selected_space_path 'experiment_id/best_selected_space.json'` + +# retrain stage +`python pdarts_retrain.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --best_selected_space_path 'experiment_id/best_selected_space.json' --best_checkpoint_dir 'experiment_id/' --trial_id 0 --batch_size 96 
--epochs 2 --lr 0.025 --layers 20 --channels 36` + +# output file +`result.json` +``` +{'type': 'Accuracy', 'result': {'sequence': 0, 'category': 'epoch', 'value': 0.1}} +{'type': 'Accuracy', 'result': {'sequence': 1, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 2, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 3, 'category': 'epoch', 'value': 0.0}} +{'type': 'Accuracy', 'result': {'sequence': 4, 'category': 'epoch', 'value': 0.0}} +{'type': 'Cost_time', 'result': {'value': '41.614346981048584 s'}} +``` + +`search_space.json` +``` +{ + "op_list": { + "_type": "layer_choice", + "_value": [ + "maxpool", + "avgpool", + "skipconnect", + "sepconv3x3", + "sepconv5x5", + "dilconv3x3", + "dilconv5x5", + "none" + ] + }, + "search_space": { + "normal_n2_p0": "op_list", + "normal_n2_p1": "op_list", + "normal_n2_switch": { + "_type": "input_choice", + "_value": { + "candidates": [ + "normal_n2_p0", + "normal_n2_p1" + ], + "n_chosen": 2 + } + }, + + ... + } +``` + +`best_selected_space.json` +``` +{ + { + "op_list": { + "_type": "layer_choice", + "_value": [ + "maxpool", + "avgpool", + "skipconnect", + "sepconv3x3", + "sepconv5x5", + "dilconv3x3", + "dilconv5x5" + ] + }, + "best_selected_space": { + "normal_n2_p0": [ + false, + false, + false, + false, + true, + false, + false + ], + "normal_n2_p1": [ + true, + false, + false, + false, + false, + false, + false + ], + ... +} +``` \ No newline at end of file diff --git a/dubhe-tadl/retrainer.py b/dubhe-tadl/retrainer.py new file mode 100644 index 0000000..4424bc6 --- /dev/null +++ b/dubhe-tadl/retrainer.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod + + +class Retrainer(ABC): + + """ + Train the best performance model from scratch without structure optimization. 
+ To implement a new selector, users need to implement: + method: "train" + method: "__init__" + super().__init__() must be called in __init__ method + + parameters: + ----------- + candidates: candidates to be evaluated + """ + + @abstractmethod + def train(self): + """ + Override the method to train. + """ + raise NotImplementedError + + def validate(self): + """ + Override the method to validate. + """ + raise NotImplementedError + + def export(self, file): + """ + Override the method to export to file. + + Parameters + ---------- + file : str + File path to export to. + """ + raise NotImplementedError + + def checkpoint(self): + """ + Override to dump a checkpoint. + """ + raise NotImplementedError \ No newline at end of file diff --git a/dubhe-tadl/selector.py b/dubhe-tadl/selector.py new file mode 100644 index 0000000..29145c3 --- /dev/null +++ b/dubhe-tadl/selector.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod + + +class Selector(ABC): + """ + choose the best model from a group of candidates. + To implement a new selector, users need to implement: + method: "fit" + method: "__init__" + super().__init__() must be called in __init__ method + + parameters: + ----------- + candidates: candidates to be evaluated + + ##### Examples ##### + + # class HPOSelector(Selector): + # def __init__(self, *args, single_candidate=True): + # super().__init__(single_candidate) + # self.args = args + + # def fit(self): + # + # # only one candatite, function passed + # + # pass + + ########### + + """ + + @abstractmethod + def __init__(self, single_candidate=True): + self.single_candidate = single_candidate + self._valid() + + @abstractmethod + def fit(self, candidates=None): + """ + evaluate the candidates to select the best one. + any optimization algos could be implement here. 
+ if the inputs has only one candidates, just return the candidate directly + """ + raise NotImplementedError + + def _valid(self, ): + if self.single_candidate: + print("### single model, selecting finished ###") + exit(0) + + + + + + diff --git a/dubhe-tadl/spos/algorithms/random/__init__.py b/dubhe-tadl/spos/algorithms/random/__init__.py new file mode 100644 index 0000000..b410226 --- /dev/null +++ b/dubhe-tadl/spos/algorithms/random/__init__.py @@ -0,0 +1 @@ +from .mutator import RandomMutator diff --git a/dubhe-tadl/spos/algorithms/random/mutator.py b/dubhe-tadl/spos/algorithms/random/mutator.py new file mode 100644 index 0000000..f302db5 --- /dev/null +++ b/dubhe-tadl/spos/algorithms/random/mutator.py @@ -0,0 +1,36 @@ +import torch +import torch.nn.functional as F + +from nni.nas.pytorch.mutator import Mutator +from nni.nas.pytorch.mutables import LayerChoice, InputChoice + + +class RandomMutator(Mutator): + """ + Random mutator that samples a random candidate in the search space each time ``reset()``. + It uses random function in PyTorch, so users can set seed in PyTorch to ensure deterministic behavior. + """ + + def sample_search(self): + """ + Sample a random candidate. + """ + result = dict() + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + gen_index = torch.randint(high=len(mutable), size=(1, )) + result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool() + elif isinstance(mutable, InputChoice): + if mutable.n_chosen is None: + result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool() + else: + perm = torch.randperm(mutable.n_candidates) + mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)] + result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable + return result + + def sample_final(self): + """ + Same as :meth:`sample_search`. 
+ """ + return self.sample_search() diff --git a/dubhe-tadl/spos/algorithms/spos/__init__.py b/dubhe-tadl/spos/algorithms/spos/__init__.py new file mode 100644 index 0000000..50b568e --- /dev/null +++ b/dubhe-tadl/spos/algorithms/spos/__init__.py @@ -0,0 +1,3 @@ +from .evolution import SPOSEvolution +from .mutator import SPOSSupernetTrainingMutator +from .trainer import SPOSSupernetTrainer diff --git a/dubhe-tadl/spos/algorithms/spos/evolution.py b/dubhe-tadl/spos/algorithms/spos/evolution.py new file mode 100644 index 0000000..9a20ec8 --- /dev/null +++ b/dubhe-tadl/spos/algorithms/spos/evolution.py @@ -0,0 +1,223 @@ +import os +import re +import json +import logging +from collections import deque + +import numpy as np +from nni.tuner import Tuner +# from nni.algorithms.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE # TODO +LAYER_CHOICE = "layer_choice" +INPUT_CHOICE = "input_choice" + + +_logger = logging.getLogger(__name__) + + +class SPOSEvolution(Tuner): + """ + SPOS evolution tuner. + + Parameters + ---------- + max_epochs : int + Maximum number of epochs to run. + num_select : int + Number of survival candidates of each epoch. + num_population : int + Number of candidates at the start of each epoch. If candidates generated by + crossover and mutation are not enough, the rest will be filled with random + candidates. + m_prob : float + The probability of mutation. + num_crossover : int + Number of candidates generated by crossover in each epoch. + num_mutation : int + Number of candidates generated by mutation in each epoch. 
+ """ + + def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1, + num_crossover=25, num_mutation=25): + assert num_population >= num_select + self.max_epochs = max_epochs + self.num_select = num_select + self.num_population = num_population + self.m_prob = m_prob + self.num_crossover = num_crossover + self.num_mutation = num_mutation + self.epoch = 0 + self.candidates = [] + self.search_space = None + self.random_state = np.random.RandomState(0) + + # async status + self._to_evaluate_queue = deque() + self._sending_parameter_queue = deque() + self._pending_result_ids = set() + self._reward_dict = dict() + self._id2candidate = dict() + self._st_callback = None + + def update_search_space(self, search_space): + """ + Handle the initialization/update event of search space. + """ + self._search_space = search_space + self._next_round() + + def _next_round(self): + _logger.info("Epoch %d, generating...", self.epoch) + if self.epoch == 0: + self._get_random_population() + self.export_results(self.candidates) + else: + best_candidates = self._select_top_candidates() + self.export_results(best_candidates) + if self.epoch >= self.max_epochs: + return + self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates) + self._get_random_population() + self.epoch += 1 + + def _random_candidate(self): + chosen_arch = dict() + for key, val in self._search_space.items(): + if val["_type"] == LAYER_CHOICE: + choices = val["_value"] + index = self.random_state.randint(len(choices)) + chosen_arch[key] = {"_value": choices[index], "_idx": index} + elif val["_type"] == INPUT_CHOICE: + raise NotImplementedError("Input choice is not implemented yet.") + return chosen_arch + + def _add_to_evaluate_queue(self, cand): + _logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand)) + self._reward_dict[self._hashcode(cand)] = 0. 
+ self._to_evaluate_queue.append(cand) + + def _get_random_population(self): + while len(self.candidates) < self.num_population: + cand = self._random_candidate() + if self._is_legal(cand): + _logger.info("Random candidate generated.") + self._add_to_evaluate_queue(cand) + self.candidates.append(cand) + + def _get_crossover(self, best): + result = [] + for _ in range(10 * self.num_crossover): + cand_p1 = best[self.random_state.randint(len(best))] + cand_p2 = best[self.random_state.randint(len(best))] + assert cand_p1.keys() == cand_p2.keys() + cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k] + for k in cand_p1.keys()} + if self._is_legal(cand): + result.append(cand) + self._add_to_evaluate_queue(cand) + if len(result) >= self.num_crossover: + break + _logger.info("Found %d architectures with crossover.", len(result)) + return result + + def _get_mutation(self, best): + result = [] + for _ in range(10 * self.num_mutation): + cand = best[self.random_state.randint(len(best))].copy() + mutation_sample = np.random.random_sample(len(cand)) + for s, k in zip(mutation_sample, cand): + if s < self.m_prob: + choices = self._search_space[k]["_value"] + index = self.random_state.randint(len(choices)) + cand[k] = {"_value": choices[index], "_idx": index} + if self._is_legal(cand): + result.append(cand) + self._add_to_evaluate_queue(cand) + if len(result) >= self.num_mutation: + break + _logger.info("Found %d architectures with mutation.", len(result)) + return result + + def _get_architecture_repr(self, cand): + # 只取出_value的值 --> "{2,3,2,1,3,2,1,0,...}" + return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1", + self._hashcode(cand)) + + def _is_legal(self, cand): + if self._hashcode(cand) in self._reward_dict: + return False + return True + + def _select_top_candidates(self): + reward_query = lambda cand: self._reward_dict[self._hashcode(cand)] + _logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates))) + 
result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select] + _logger.info("Best candidate rewards: %s", list(map(reward_query, result))) + return result + + @staticmethod + def _hashcode(d): + return json.dumps(d, sort_keys=True) + + def _bind_and_send_parameters(self): + """ + There are two types of resources: parameter ids and candidates. This function is called at + necessary times to bind these resources to send new trials with st_callback. + """ + result = [] + while self._sending_parameter_queue and self._to_evaluate_queue: + parameter_id = self._sending_parameter_queue.popleft() + parameters = self._to_evaluate_queue.popleft() + self._id2candidate[parameter_id] = parameters + result.append(parameters) + self._pending_result_ids.add(parameter_id) + self._st_callback(parameter_id, parameters) + _logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters)) + return result + + def generate_multiple_parameters(self, parameter_id_list, **kwargs): + """ + Callback function necessary to implement a tuner. This will put more parameter ids into the + parameter id queue. + """ + if "st_callback" in kwargs and self._st_callback is None: + self._st_callback = kwargs["st_callback"] + for parameter_id in parameter_id_list: + self._sending_parameter_queue.append(parameter_id) + self._bind_and_send_parameters() + return [] # always not use this. might induce problem of over-sending + + def receive_trial_result(self, parameter_id, parameters, value, **kwargs): + """ + Callback function. Receive a trial result. + """ + _logger.info("Candidate %d, reported reward %f", parameter_id, value) + self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value + + def trial_end(self, parameter_id, success, **kwargs): + """ + Callback function when a trial is ended and resource is released. 
+ """ + self._pending_result_ids.remove(parameter_id) + if not self._pending_result_ids and not self._to_evaluate_queue: + # a new epoch now + self._next_round() + assert self._st_callback is not None + self._bind_and_send_parameters() + + def export_results(self, result): + """ + Export a number of candidates to `checkpoints` dir. + + Parameters + ---------- + result : dict + Chosen architectures to be exported. + """ + os.makedirs("checkpoints", exist_ok=True) + for i, cand in enumerate(result): + converted = dict() + for cand_key, cand_val in cand.items(): + onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))] + converted[cand_key] = onehot + with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp: + json.dump(converted, fp) diff --git a/dubhe-tadl/spos/algorithms/spos/mutator.py b/dubhe-tadl/spos/algorithms/spos/mutator.py new file mode 100644 index 0000000..85981e6 --- /dev/null +++ b/dubhe-tadl/spos/algorithms/spos/mutator.py @@ -0,0 +1,65 @@ +import sys +sys.path.insert(0, "../../") + +import logging +import numpy as np +from pytorch.algorithms.random import RandomMutator + +_logger = logging.getLogger(__name__) + + +class SPOSSupernetTrainingMutator(RandomMutator): + """ + A random mutator with flops limit. + + Parameters + ---------- + model : nn.Module + PyTorch model. + flops_func : callable + Callable that takes a candidate from `sample_search` and returns its candidate. When `flops_func` + is None, functions related to flops will be deactivated. + flops_lb : number + Lower bound of flops. + flops_ub : number + Upper bound of flops. + flops_bin_num : number + Number of bins divided for the interval of flops to ensure the uniformity. Bigger number will be more + uniform, but the sampling will be slower. + flops_sample_timeout : int + Maximum number of attempts to sample before giving up and use a random candidate. 
+ """ + def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None, + flops_bin_num=7, flops_sample_timeout=500): + + super().__init__(model) + self._flops_func = flops_func + if self._flops_func is not None: + self._flops_bin_num = flops_bin_num + self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)] + self._flops_sample_timeout = flops_sample_timeout + + def sample_search(self): + """ + Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly + relative to flops. + + Returns + ------- + dict + """ + if self._flops_func is not None: + for times in range(self._flops_sample_timeout): + idx = np.random.randint(self._flops_bin_num) + cand = super().sample_search() + if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]: + _logger.debug("Sampled candidate flops %f in %d times.", cand, times) + return cand + _logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout) + return super().sample_search() + + def sample_final(self): + """ + Implement only to suffice the interface of Mutator. + """ + return self.sample_search() diff --git a/dubhe-tadl/spos/algorithms/spos/trainer.py b/dubhe-tadl/spos/algorithms/spos/trainer.py new file mode 100644 index 0000000..d99e954 --- /dev/null +++ b/dubhe-tadl/spos/algorithms/spos/trainer.py @@ -0,0 +1,100 @@ +import logging + +import torch +from pytorch.trainer import Trainer +from pytorch.utils import AverageMeterGroup + +from .mutator import SPOSSupernetTrainingMutator + +logger = logging.getLogger(__name__) + + +class SPOSSupernetTrainer(Trainer): + """ + This trainer trains a supernet that can be used for evolution search. + + Parameters + ---------- + model : nn.Module + Model with mutables. + mutator : Mutator + A mutator object that has been initialized with the model. + loss : callable + Called with logits and targets. Returns a loss tensor. 
+ metrics : callable + Returns a dict that maps metrics keys to metrics data. + optimizer : Optimizer + Optimizer that optimizes the model. + num_epochs : int + Number of epochs of training. + train_loader : iterable + Data loader of training. Raise ``StopIteration`` when one epoch is exhausted. + dataset_valid : iterable + Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted. + batch_size : int + Batch size. + workers: int + Number of threads for data preprocessing. Not used for this trainer. Maybe removed in future. + device : torch.device + Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will + automatic detects GPU and selects GPU first. + log_frequency : int + Number of mini-batches to log metrics. + callbacks : list of Callback + Callbacks to plug into the trainer. See Callbacks. + """ + + def __init__(self, model, loss, metrics, + optimizer, num_epochs, train_loader, valid_loader, + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, + callbacks=None): + assert torch.cuda.is_available() + super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model), + loss, metrics, optimizer, num_epochs, None, None, + batch_size, workers, device, log_frequency, callbacks) + + self.train_loader = train_loader + self.valid_loader = valid_loader + + def train_one_epoch(self, epoch): + self.model.train() + meters = AverageMeterGroup() + # print("length is {}".format(len(self.train_loader))) + length = len(self.train_loader) + for step, (x, y) in enumerate(self.train_loader): + x, y = x.to(self.device), y.to(self.device) + self.optimizer.zero_grad() + self.mutator.reset() + logits = self.model(x) + loss = self.loss(logits, y) + loss.backward() + self.optimizer.step() + + metrics = self.metrics(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + 
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.train_loader), meters) + if step>length: + break + + + def validate_one_epoch(self, epoch): + self.model.eval() + meters = AverageMeterGroup() + length = len(self.valid_loader) + with torch.no_grad(): + for step, (x, y) in enumerate(self.valid_loader): + x, y = x.to(self.device), y.to(self.device) + self.mutator.reset() + logits = self.model(x) + loss = self.loss(logits, y) + metrics = self.metrics(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.valid_loader), meters) + if step>length: + break diff --git a/dubhe-tadl/spos/blocks.py b/dubhe-tadl/spos/blocks.py new file mode 100644 index 0000000..4ce7a8b --- /dev/null +++ b/dubhe-tadl/spos/blocks.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn + + +class ShuffleNetBlock(nn.Module): + """ + When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels. 
+ """ + + def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp", affine=True): + super().__init__() + assert stride in [1, 2] + assert ksize in [3, 5, 7] + self.channels = inp // 2 if stride == 1 else inp + self.inp = inp + self.oup = oup + self.mid_channels = mid_channels + self.ksize = ksize + self.stride = stride + self.pad = ksize // 2 + self.oup_main = oup - self.channels + self._affine = affine + assert self.oup_main > 0 + + self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence)) + + if stride == 2: + self.branch_proj = nn.Sequential( + # dw + nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad, + groups=self.channels, bias=False), + nn.BatchNorm2d(self.channels, affine=affine), + # pw-linear + nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(self.channels, affine=affine), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + if self.stride == 2: + x_proj, x = self.branch_proj(x), x + else: + x_proj, x = self._channel_shuffle(x) + return torch.cat((x_proj, self.branch_main(x)), 1) + + def _decode_point_depth_conv(self, sequence): + result = [] + first_depth = first_point = True + pc = c = self.channels + for i, token in enumerate(sequence): + # compute output channels of this conv + if i + 1 == len(sequence): + assert token == "p", "Last conv must be point-wise conv." + c = self.oup_main + elif token == "p" and first_point: + c = self.mid_channels + if token == "d": + # depth-wise conv + assert pc == c, "Depth-wise conv must not change channels." 
+ result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad, + groups=c, bias=False)) + result.append(nn.BatchNorm2d(c, affine=self._affine)) + first_depth = False + elif token == "p": + # point-wise conv + result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False)) + result.append(nn.BatchNorm2d(c, affine=self._affine)) + result.append(nn.ReLU(inplace=True)) + first_point = False + else: + raise ValueError("Conv sequence must be d and p.") + pc = c + return result + + def _channel_shuffle(self, x): + bs, num_channels, height, width = x.data.size() + assert (num_channels % 4 == 0) + x = x.reshape(bs * num_channels // 2, 2, height * width) + x = x.permute(1, 0, 2) + x = x.reshape(2, -1, num_channels // 2, height, width) + return x[0], x[1] + + +class ShuffleXceptionBlock(ShuffleNetBlock): + + def __init__(self, inp, oup, mid_channels, stride, affine=True): + super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp", affine) diff --git a/dubhe-tadl/spos/data/op_flops_dict.pkl b/dubhe-tadl/spos/data/op_flops_dict.pkl new file mode 100644 index 0000000..5fdb31f Binary files /dev/null and b/dubhe-tadl/spos/data/op_flops_dict.pkl differ diff --git a/dubhe-tadl/spos/dataloader.py b/dubhe-tadl/spos/dataloader.py new file mode 100644 index 0000000..b83c32f --- /dev/null +++ b/dubhe-tadl/spos/dataloader.py @@ -0,0 +1,125 @@ +import os + +import nvidia.dali.ops as ops +import nvidia.dali.types as types +import torch.utils.data +from nvidia.dali.pipeline import Pipeline +from nvidia.dali.plugin.pytorch import DALIClassificationIterator + + +class HybridTrainPipe(Pipeline): + def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed=12, local_rank=0, world_size=1, + spos_pre=False): + super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id) + color_space_type = types.BGR if spos_pre else types.RGB + self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, 
random_shuffle=True) + self.decode = ops.ImageDecoder(device="mixed", output_type=types.BGR) # color_space_type + self.res = ops.RandomResizedCrop(device="gpu", size=crop, + interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR) + self.twist = ops.ColorTwist(device="gpu") + self.jitter_rng = ops.Uniform(range=[0.6, 1.4]) + # self.cmnp = ops.CropMirrorNormalize(device="gpu", + # dtype = types.FLOAT, # output_dtype=types.FLOAT, + # output_layout=types.NCHW, + # # image_type=color_space_type, # 该功能被删掉了,在ImageDecoder中即可完成 + # mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], + # std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255]) + self.cmnp = ops.CropMirrorNormalize(device="gpu", dtype = types.FLOAT, output_layout=types.NCHW, + mean= [0.485 * 255, 0.456 * 255, 0.406 * 255], + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + ) + self.coin = ops.CoinFlip(probability=0.5) + + def define_graph(self): + rng = self.coin() + self.jpegs, self.labels = self.input(name="Reader") + images = self.decode(self.jpegs) + images = self.res(images) + images = self.twist(images, saturation=self.jitter_rng(), + contrast=self.jitter_rng(), brightness=self.jitter_rng()) + output = self.cmnp(images, mirror=rng) # 临时删除,测试准确率为零是否是数据处理的原因 + return [output, self.labels] # output + + +class HybridValPipe(Pipeline): + def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed=12, local_rank=0, world_size=1, + spos_pre=False, shuffle=False): + super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id) + color_space_type = types.BGR if spos_pre else types.RGB + self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, + random_shuffle=shuffle) + self.decode = ops.ImageDecoder(device="mixed", output_type=types.BGR) + self.res = ops.Resize(device="gpu", resize_shorter=size, + interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR) + # self.cmnp 
= ops.CropMirrorNormalize(device="gpu", + # dtype = types.FLOAT, # output_dtype=types.FLOAT, + # output_layout=types.NCHW, + # crop=(crop, crop), + # # image_type=color_space_type, + # mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], + # std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255]) + self.cmnp = ops.CropMirrorNormalize(device="gpu", + dtype = types.FLOAT, # output_dtype=types.FLOAT, + output_layout=types.NCHW, + crop=(crop, crop), + # image_type=color_space_type, + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255], + std = [0.229 * 255, 0.224 * 255, 0.225 * 255]) + + + def define_graph(self): + self.jpegs, self.labels = self.input(name="Reader") + images = self.decode(self.jpegs) + images = self.res(images) + output = self.cmnp(images) + return [output, self.labels] + + +class ClassificationWrapper: + def __init__(self, loader, size): + self.loader = loader + self.size = size + + def __iter__(self): + return self + + def __next__(self): + data = next(self.loader) + return data[0]["data"], data[0]["label"].view(-1).long().cuda(device="cuda:0", non_blocking=True) # .cuda(non_blocking=True) + + def __len__(self): + return self.size + + +def get_imagenet_iter_dali(split, image_dir, batch_size, num_threads, crop=224, val_size=256, + spos_preprocessing=False, seed=12, shuffle=False, device_id=None): + world_size, local_rank = 1, 0 + if device_id is None: + device_id = torch.cuda.device_count() - 1 # use last gpu + if split == "train": + pipeline = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, + data_dir=os.path.join(image_dir, "train"), seed=seed, + crop=crop, world_size=world_size, local_rank=local_rank, + spos_pre=spos_preprocessing) + elif split == "val": + pipeline = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, + data_dir=os.path.join(image_dir, "val"), seed=seed, + crop=crop, size=val_size, world_size=world_size, local_rank=local_rank, + 
class Evaluator:
    """
    Retrain the BatchNorm statistics of a fixed architecture and evaluate its accuracy.

    Parameters
    ----------
    imagenet_dir : str
        Dataset root handed to ``get_imagenet_iter_dali``.
    checkpoint : str
        Supernet checkpoint loaded through ``load_and_parse_state_dict``.
    spos_preprocessing : bool
        Forwarded to ``get_imagenet_iter_dali`` as ``spos_preprocessing``.
    seed : int
        Seed applied to ``torch``, ``numpy`` and ``random``.
    workers : int
        Number of loader threads forwarded to ``get_imagenet_iter_dali``.
    train_batch_size : int
        Batch size of the loader used for BatchNorm recalibration.
    train_iters : int
        Number of recalibration steps.
    test_batch_size : int
        Batch size of the validation loader.
    log_frequency : int
        Progress is printed every ``log_frequency`` steps.

    Raises
    ------
    RuntimeError
        If CUDA is not available.
    """

    def __init__(self, imagenet_dir="/mnt/local/hanjiayi/imagenet",  # imagenet dataset
                 checkpoint="./data/checkpoint-150000.pth.tar",  # fine model from supernet
                 spos_preprocessing=True,  # RGB or BGR
                 seed=42,  # torch.manual_seed
                 workers=1,  # the number of subprocess
                 train_batch_size=128,
                 train_iters=200,
                 test_batch_size=512,
                 log_frequency=10,
                 ):
        self.imagenet_dir = imagenet_dir
        self.checkpoint = checkpoint
        self.spos_preprocessing = spos_preprocessing
        self.seed = seed
        self.workers = workers
        self.train_batch_size = train_batch_size
        self.train_iters = train_iters
        self.test_batch_size = test_batch_size
        self.log_frequency = log_frequency

        print("### program interval 1 ###")
        self.model = ShuffleNetV2OneShot()
        print("### program interval 2 ###")

        print("## test&retrain -- load model ## begin to load model")
        self.model.load_state_dict(load_and_parse_state_dict(filepath=self.checkpoint))
        print("## test&retrain -- load model ## model loaded")

        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.backends.cudnn.deterministic = True

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not torch.cuda.is_available():
            raise RuntimeError("Evaluator requires a CUDA device.")

        self.criterion = CrossEntropyLabelSmooth(1000, 0.1)

        print("##### load training data #####")
        self.train_loader = get_imagenet_iter_dali("train", self.imagenet_dir, self.train_batch_size, self.workers,
                                                   spos_preprocessing=self.spos_preprocessing,
                                                   seed=self.seed, device_id=0)
        print("##### training data loaded finished #####")

        print("##### load validating data #####")
        self.val_loader = get_imagenet_iter_dali("val", self.imagenet_dir, self.test_batch_size, self.workers,
                                                 spos_preprocessing=self.spos_preprocessing, shuffle=True,
                                                 seed=self.seed, device_id=0)
        print("##### validating data loaded finished #####")

    @staticmethod
    def _endless(loader):
        """Yield batches from ``loader`` forever by re-iterating it, without caching any batch."""
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                # An empty (or exhausted one-shot) loader would otherwise spin forever.
                return

    @staticmethod
    def _result_path(architecture):
        """Return the path of the accuracy file for ``architecture`` (``./acc/<file name>``)."""
        return "./acc/{}".format(os.path.basename(architecture))

    @staticmethod
    def _reset_bn_statistics(model):
        """Reset running mean to zeros and running variance to ones on every ``BatchNorm2d``."""
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d) and m.running_mean is not None:
                m.running_mean = torch.zeros_like(m.running_mean)
                m.running_var = torch.ones_like(m.running_var)

    def retrain_bn(self, model, criterion, max_iters, log_freq, loader):
        """
        Recalibrate BatchNorm statistics by running ``max_iters`` forward passes in train mode.

        No gradients are computed and no optimizer is involved; only the running statistics
        change. ``loader`` must support ``next()`` and yield ``(inputs, targets)``.
        """
        with torch.no_grad():
            print("clear BN statistics")
            self._reset_bn_statistics(model)

            print("Train BN with training set (BN sanitize)...")
            model.train()
            meters = AverageMeterGroup()
            start_time = time.time()

            for step in range(max_iters):
                inputs, targets = next(loader)
                logits = model(inputs)
                loss = criterion(logits, targets)
                metrics = accuracy(logits, targets)
                metrics["loss"] = loss.item()
                meters.update(metrics)
                if step % log_freq == 0 or step + 1 == max_iters:
                    print("Train Step [%d/%d] %s time %.3fs " % (step + 1, max_iters, meters, time.time() - start_time))

    def test_acc(self, model, criterion, log_freq, loader):
        """
        Evaluate ``model`` on one pass over ``loader`` and return the average top-1 accuracy.

        ``loader`` must be iterable, yield ``(inputs, targets)`` and support ``len()``.
        """
        print("start testing...")
        model.eval()
        meters = AverageMeterGroup()
        start_time = time.time()
        total_steps = len(loader)

        with torch.no_grad():
            for step, (inputs, targets) in enumerate(loader):
                logits = model(inputs)
                loss = criterion(logits, targets)
                metrics = accuracy(logits, targets)
                metrics["loss"] = loss.item()
                meters.update(metrics)
                if step % log_freq == 0 or step + 1 == total_steps:
                    print("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f" %
                          (step + 1, total_steps, time.time() - start_time,
                           meters.acc1.avg, meters.acc5.avg, meters.loss.avg))
                # Stop after exactly one pass. The old guard (``step > len(loader)``) only
                # fired two batches late.
                if step + 1 >= total_steps:
                    break
        return meters.acc1.avg

    def evaluate_acc(self, model, criterion, loader_train, loader_test):
        """
        Recalibrate BatchNorm on ``loader_train``, then return top-1 accuracy on ``loader_test``.

        Raises
        ------
        TypeError
            If the measured accuracy is not a ``float``.
        """
        self.retrain_bn(model, criterion, self.train_iters, self.log_frequency, loader_train)
        acc = self.test_acc(model, criterion, self.log_frequency, loader_test)
        if not isinstance(acc, float):
            raise TypeError("accuracy must be a float, got {!r}".format(type(acc)))
        torch.cuda.empty_cache()
        return acc

    def eval_model(self, epoch, architecture):
        """
        Apply ``architecture`` to the supernet, evaluate it and write the accuracy to ``./acc``.

        Parameters
        ----------
        epoch : int
            Accepted for interface compatibility; not used here.
        architecture : str
            Path of the architecture file passed to ``apply_fixed_architecture``. The result is
            stored as ``{architecture: accuracy}`` in ``./acc/<file name of architecture>``.
        """
        print("## test&retrain -- apply architecture ## begin to apply architecture to model")
        apply_fixed_architecture(self.model, architecture)
        print("## test&retrain -- apply architecture ## architecture applied")

        self.model.cuda(0)
        # ``itertools.cycle`` keeps a reference to every batch it has seen and was re-wrapped on
        # each call; a plain re-iterating generator gives the same endless stream without either.
        train_stream = self._endless(self.train_loader)
        acc = self.evaluate_acc(self.model, self.criterion, train_stream, self.val_loader)

        # Persist the final accuracy, keyed by the architecture path, e.g.
        # {"./checkpoints/000_001.json": 0.56}
        os.makedirs("./acc", exist_ok=True)
        with open(self._result_path(architecture), "w") as f:
            json.dump({architecture: acc}, f)
# Module-level constants and logger the two tuner classes below rely on.
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"


_logger = logging.getLogger(__name__)


class SPOSEvolution:
    """
    SPOS evolution tuner.

    Candidates are exchanged with the evaluator through files: one-hot architecture files in
    ``self.cand_path``, accuracy files in ``self.acc_path`` and a per-epoch
    ``id2cand/<epoch>_id2candidate.json`` index mapping file names to candidate hash codes.

    Parameters
    ----------
    max_epochs : int
        Maximum number of epochs to run.
    num_select : int
        Number of survival candidates of each epoch.
    num_population : int
        Number of candidates at the start of each epoch. If candidates generated by
        crossover and mutation are not enough, the rest will be filled with random
        candidates.
    m_prob : float
        The probability of mutation.
    num_crossover : int
        Number of candidates generated by crossover in each epoch.
    num_mutation : int
        Number of candidates generated by mutation in each epoch.
    epoch : int
        Epoch to start from. When non-zero, the candidates exported for ``epoch - 1`` are
        loaded from ``self.cand_path``.

    Raises
    ------
    ValueError
        If ``num_population`` is smaller than ``num_select``.
    """

    def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
                 num_crossover=25, num_mutation=25, epoch=0):
        if num_population < num_select:
            raise ValueError("num_population must be >= num_select")
        self.max_epochs = max_epochs
        self.num_select = num_select
        self.num_population = num_population
        self.m_prob = m_prob
        self.num_crossover = num_crossover
        self.num_mutation = num_mutation
        self.epoch = epoch
        self.search_space = None  # historical attribute, kept for backward compatibility
        self._search_space = None  # the attribute every method actually reads
        self.random_state = np.random.RandomState(0)

        # async status
        self._to_evaluate_queue = deque()
        self._sending_parameter_queue = deque()
        self._pending_result_ids = set()
        self._reward_dict = dict()
        self._id2candidate = dict()  # parameter id -> candidate (callback interface)
        self.id2candidate = dict()  # exported file name -> candidate hash code
        self._st_callback = None
        self.cand_path = "./checkpoints"
        self.acc_path = "./acc"
        # No population has been generated yet in the very first epoch.
        self.candidates = [] if epoch == 0 else self.load_candidates()

    @staticmethod
    def _epoch_files(directory, epoch):
        """Return the sorted names of the files in ``directory`` that start with ``"%03d_" % epoch``."""
        prefix = "%03d_" % epoch
        return sorted(
            name for name in os.listdir(directory)
            if name.startswith(prefix) and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _onehot_index(onehot):
        """Return the position of the first truthy entry of ``onehot``; raise ``ValueError`` if none."""
        for position, flag in enumerate(onehot):
            if flag:
                return position
        raise ValueError("one-hot vector has no active entry: {!r}".format(onehot))

    def load_candidates(self):
        """
        Load the candidates that ``export_results`` wrote for the previous epoch.

        The one-hot file format is converted back to the in-memory format:
        ``{"LayerChoice1": [false, false, false, true], ...}`` becomes
        ``{"LayerChoice1": {"_idx": 3, "_value": "3"}, ...}``.
        """
        print("## evolution -- load ## begin to load candidates in evolution...\n")
        candidates = []
        for name in self._epoch_files(self.cand_path, self.epoch - 1):
            with open(os.path.join(self.cand_path, name), "r") as f:
                exported = json.load(f)
            cand = {}
            for key, onehot in exported.items():
                idx = self._onehot_index(onehot)
                cand[key] = {"_value": str(idx), "_idx": idx}
            candidates.append(cand)
        print("## evolution -- load ## candidates loaded \n")
        return candidates

    def load_id2candidate(self):
        """Read the file-name -> hash-code index written for the previous epoch."""
        with open("./id2cand/%03d_id2candidate.json" % (self.epoch - 1), "r") as f:
            self.id2candidate = json.load(f)

    def update_search_space(self, search_space):
        """
        Handle the initialization/update event of search space.
        """
        print("## evolution -- update ## updating search space")
        self._search_space = search_space
        self._next_round()
        print("## evolution -- update ## search space updated")

    def _next_round(self):
        """Generate, export and evaluate the population of the current epoch, then advance."""
        _logger.info("Epoch %d, generating...", self.epoch)
        if self.epoch == 0:
            self._get_random_population()
            self.export_results(self.candidates)
            self.evaluate_cands()  # evaluate every exported model
        else:
            self.load_id2candidate()
            self.receive_trial_result()
            best_candidates = self._select_top_candidates()
            if self.epoch >= self.max_epochs:
                return
            self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
            self._get_random_population()
            self.export_results(self.candidates)
            self.evaluate_cands()  # evaluate every exported model
        self.epoch += 1

    def _random_candidate(self):
        """Draw one choice uniformly for every layer-choice key of the search space."""
        chosen_arch = dict()
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                index = self.random_state.randint(len(choices))
                chosen_arch[key] = {"_value": choices[index], "_idx": index}
            elif val["_type"] == INPUT_CHOICE:
                raise NotImplementedError("Input choice is not implemented yet.")
        return chosen_arch

    def _add_to_evaluate_queue(self, cand):
        """Register ``cand`` with a zero reward and queue it for evaluation."""
        _logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
        self._reward_dict[self._hashcode(cand)] = 0.
        self._to_evaluate_queue.append(cand)

    def _get_random_population(self):
        """Fill ``self.candidates`` with legal random candidates up to ``num_population``."""
        while len(self.candidates) < self.num_population:
            cand = self._random_candidate()
            if self._is_legal(cand):
                _logger.info("Random candidate generated.")
                self._add_to_evaluate_queue(cand)
                self.candidates.append(cand)

    def _get_crossover(self, best):
        """Create up to ``num_crossover`` legal candidates by mixing two random parents key by key."""
        result = []
        for _ in range(10 * self.num_crossover):
            cand_p1 = best[self.random_state.randint(len(best))]
            cand_p2 = best[self.random_state.randint(len(best))]
            assert cand_p1.keys() == cand_p2.keys()
            cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
                    for k in cand_p1.keys()}
            if self._is_legal(cand):
                result.append(cand)
                self._add_to_evaluate_queue(cand)
            if len(result) >= self.num_crossover:
                break
        _logger.info("Found %d architectures with crossover.", len(result))
        return result

    def _get_mutation(self, best):
        """Create up to ``num_mutation`` legal candidates by re-drawing each key with probability ``m_prob``."""
        result = []
        for _ in range(10 * self.num_mutation):
            cand = best[self.random_state.randint(len(best))].copy()
            # Use the tuner's own seeded generator; the global ``np.random`` state is not
            # controlled by this class and made mutation non-reproducible.
            mutation_sample = self.random_state.random_sample(len(cand))
            for s, k in zip(mutation_sample, cand):
                if s < self.m_prob:
                    choices = self._search_space[k]["_value"]
                    index = self.random_state.randint(len(choices))
                    cand[k] = {"_value": choices[index], "_idx": index}
            if self._is_legal(cand):
                result.append(cand)
                self._add_to_evaluate_queue(cand)
            if len(result) >= self.num_mutation:
                break
        _logger.info("Found %d architectures with mutation.", len(result))
        return result

    def _get_architecture_repr(self, cand):
        """Return a compact log representation of ``cand`` built from its hash code."""
        return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
                      self._hashcode(cand))

    def _is_legal(self, cand):
        """A candidate is legal when it has not been registered in the reward table yet."""
        return self._hashcode(cand) not in self._reward_dict

    def evaluate_cands(self):
        """
        Retrain and evaluate every architecture file exported for the current epoch.

        Each file is handed to ``evaluator.py`` in a separate interpreter process. A non-zero
        exit status is logged as a warning; the remaining candidates are still evaluated.
        """
        print("## evolution -- evaluate ## begin to evaluate candidates...")
        for name in self._epoch_files(self.cand_path, self.epoch):
            arch_file = os.path.join(self.cand_path, name)
            completed = subprocess.run([sys.executable,
                                        "evaluator.py", "--architecture", arch_file, "--epoch", str(self.epoch)])
            if completed.returncode != 0:
                _logger.warning("Evaluation of %s exited with status %d.", arch_file, completed.returncode)
        print("## evolution -- evaluate ## candidates evaluated")

    def _select_top_candidates(self):
        """Return the ``num_select`` candidates with the highest recorded reward."""
        print("## evolution -- select ## begin to select top candidates...")
        reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
        _logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
        result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
        _logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
        print("## evolution -- select ## selected done")
        return result

    @staticmethod
    def _hashcode(d):
        """Return the canonical JSON string (sorted keys) used as the identity of a candidate."""
        return json.dumps(d, sort_keys=True)

    def _bind_and_send_parameters(self):
        """
        There are two types of resources: parameter ids and candidates. This function is called at
        necessary times to bind these resources to send new trials with st_callback.
        """
        result = []
        while self._sending_parameter_queue and self._to_evaluate_queue:
            parameter_id = self._sending_parameter_queue.popleft()
            parameters = self._to_evaluate_queue.popleft()
            self._id2candidate[parameter_id] = parameters
            result.append(parameters)
            self._pending_result_ids.add(parameter_id)
            self._st_callback(parameter_id, parameters)
            _logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
        return result

    def generate_multiple_parameters(self, parameter_id_list, **kwargs):
        """
        Callback function necessary to implement a tuner. This will put more parameter ids into the
        parameter id queue.
        """
        if "st_callback" in kwargs and self._st_callback is None:
            self._st_callback = kwargs["st_callback"]
        for parameter_id in parameter_id_list:
            self._sending_parameter_queue.append(parameter_id)
        self._bind_and_send_parameters()
        return []  # always not use this. might induce problem of over-sending

    def receive_trial_result(self):
        """
        Read the accuracy files of the previous epoch and update ``self._reward_dict``.

        Every accuracy file holds ``{<architecture path>: <accuracy>}``. Only the file name of
        the path is used to look up the candidate hash code in ``self.id2candidate``.
        """
        acc_dict = {}
        for name in self._epoch_files(self.acc_path, self.epoch - 1):
            with open(os.path.join(self.acc_path, name), "r") as f:
                acc_dict.update(json.load(f))  # e.g. {"./checkpoints/000_001.json": 0.56}

        for arch_path, value in acc_dict.items():
            # ``str.lstrip`` removes a *set of characters*, not a prefix, so it cannot be used
            # to drop the directory part; take the base name instead.
            file_name = os.path.basename(arch_path)
            self._reward_dict[self.id2candidate[file_name]] = value

    def trial_end(self, parameter_id, success, **kwargs):
        """
        Callback function when a trial is ended and resource is released.
        """
        self._pending_result_ids.remove(parameter_id)
        if not self._pending_result_ids and not self._to_evaluate_queue:
            # a new epoch now
            self._next_round()
            assert self._st_callback is not None
            self._bind_and_send_parameters()

    def export_results(self, result):
        """
        Export a number of candidates to the candidate directory and write the epoch index.

        Parameters
        ----------
        result : list of dict
            Chosen architectures to be exported. Candidate ``i`` is written as a one-hot file
            named ``"%03d_%03d.json" % (self.epoch, i)``.
        """
        os.makedirs(self.cand_path, exist_ok=True)
        os.makedirs("id2cand", exist_ok=True)
        self.id2candidate = {}
        for i, cand in enumerate(result):
            file_name = "%03d_%03d.json" % (self.epoch, i)
            converted = {
                cand_key: [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
                for cand_key, cand_val in cand.items()
            }
            with open(os.path.join(self.cand_path, file_name), "w") as fp:
                json.dump(converted, fp)
            # Index entry: "000_000.json" -> canonical JSON string of the candidate.
            self.id2candidate[file_name] = self._hashcode(cand)
        with open("./id2cand/%03d_id2candidate.json" % self.epoch, "w") as f:
            json.dump(self.id2candidate, f)


class EvolutionWithFlops(SPOSEvolution):
    """
    This tuner extends the function of evolution tuner, by limiting the flops generated by tuner.
    Needs a function to examine the flops.

    Parameters
    ----------
    flops_limit : float
        Candidates whose FLOPs exceed this value are rejected by ``_is_legal``.
    **kwargs
        Forwarded to :class:`SPOSEvolution`.
    """

    def __init__(self, flops_limit=330E6, **kwargs):
        super().__init__(**kwargs)
        self.flops_limit = flops_limit

        # NOTE(review): pickle executes arbitrary code on load; this file must come from a
        # trusted source (it is read from the ``data`` directory next to this module).
        with open(os.path.join(os.path.dirname(__file__), "./data/op_flops_dict.pkl"), "rb") as fp:
            self._op_flops_dict = pickle.load(fp)

    def _is_legal(self, cand):
        """Legal means: not seen before and within the FLOPs budget."""
        if not super()._is_legal(cand):
            return False
        return self.get_candidate_flops(cand) <= self.flops_limit

    def get_candidate_flops(self, candidate):
        """
        this method is the same with ShuffleNetV2OneShot.get_candidate_flops, but we dont need to initialize that class.

        ``candidate`` maps each key of ``PARSED_FLOPS`` either to a dict carrying ``"_idx"``
        or to a tensor whose arg-max along dimension 0 selects the operator.
        """
        conv1_flops = self._op_flops_dict["conv1"][(3, 16, 224, 224, 2)]
        rest_flops = self._op_flops_dict["rest_operation"][(640, 1000, 7, 7, 1)]
        total_flops = conv1_flops + rest_flops
        for k, m in candidate.items():
            parsed_flops_dict = PARSED_FLOPS[k]
            if isinstance(m, dict):  # to be compatible with classical nas format
                total_flops += parsed_flops_dict[m["_idx"]]
            else:
                # ``torch`` is not imported at module level in this file, so this branch used
                # to raise NameError; import it only where it is needed.
                import torch
                total_flops += parsed_flops_dict[int(torch.max(m, 0)[1])]
        return total_flops
parser.add_argument("--num_crossover", type=int, default=2) # 25 + parser.add_argument("--num_mutation", type=int, default=2) # 25 + parser.add_argument("--max_epoches", type=int, default=3) # 20 + parser.add_argument("--trial_id", type=int, default=1) + parser.add_argument("--m_prob", type=float, default=0.1) + + parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet") # ./data/imagenet + parser.add_argument("--spos-preprocessing", default=True, + help="When true, image values will range from 0 to 255 and use BGR " + "(as in original repo).") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--train-batch-size", type=int, default=128) + parser.add_argument("--train-iters", type=int, default=200) + parser.add_argument("--test-batch-size", type=int, default=512) # nni中为512,官方repo为200 + parser.add_argument("--log-frequency", type=int, default=10) + parser.add_argument("--architecture", type=str, default="./architecture_final.json", help="load the file to retrain or eval") + + args = parser.parse_args() + + search_space = load_search_space(path=args.search_space_path) + # evl = Evaluator() + + # if args.single_trial: + # epoch = 0 + # print("*" * 50, "\n") + # print("epoch {}{}".format(epoch, "\n")) + # print("*" * 50, "\n") + # trial(args, epoch, search_space=search_space) + # else: + # for epoch in range(2, args.max_epoches+2): + # print("*"*50, "\n") + # print("epoch {}{}".format(epoch, "\n")) + # print("*"*50, "\n") + # trial(args, epoch, search_space=search_space) + + trial(args, args.trial_id, search_space) diff --git a/dubhe-tadl/spos/network.py b/dubhe-tadl/spos/network.py new file mode 100644 index 0000000..37ad037 --- /dev/null +++ b/dubhe-tadl/spos/network.py @@ -0,0 +1,203 @@ +import sys +sys.path.append("../../") + + +import os +import re +import pickle + +import torch +import torch.nn as nn +from pytorch.mutables import LayerChoice + +from blocks import ShuffleNetBlock, ShuffleXceptionBlock + 
+PARSED_FLOPS = {'LayerChoice1': [13396992, 15805440, 19418112, 13146112], + 'LayerChoice2': [7325696, 8931328, 11339776, 12343296], + 'LayerChoice3': [7325696, 8931328, 11339776, 12343296], + 'LayerChoice4': [7325696, 8931328, 11339776, 12343296], + 'LayerChoice5': [26304768, 28111104, 30820608, 20296192], + 'LayerChoice6': [10599680, 11603200, 13108480, 16746240], + 'LayerChoice7': [10599680, 11603200, 13108480, 16746240], + 'LayerChoice8': [10599680, 11603200, 13108480, 16746240], + 'LayerChoice9': [30670080, 31673600, 33178880, 21199360], + 'LayerChoice10': [10317440, 10819200, 11571840, 15899520], + 'LayerChoice11': [10317440, 10819200, 11571840, 15899520], + 'LayerChoice12': [10317440, 10819200, 11571840, 15899520], + 'LayerChoice13': [10317440, 10819200, 11571840, 15899520], + 'LayerChoice14': [10317440, 10819200, 11571840, 15899520], + 'LayerChoice15': [10317440, 10819200, 11571840, 15899520], + 'LayerChoice16': [10317440, 10819200, 11571840, 15899520], + 'LayerChoice17': [30387840, 30889600, 31642240, 20634880], + 'LayerChoice18': [10176320, 10427200, 10803520, 15476160], + 'LayerChoice19': [10176320, 10427200, 10803520, 15476160], + 'LayerChoice20': [10176320, 10427200, 10803520, 15476160]} + + +class ShuffleNetV2OneShot(nn.Module): + block_keys = [ + 'shufflenet_3x3', + 'shufflenet_5x5', + 'shufflenet_7x7', + 'xception_3x3', + ] + + def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000, + op_flops_path="./data/op_flops_dict.pkl", affine=False): + super().__init__() + + assert input_size % 32 == 0 + with open(os.path.join(os.path.dirname(__file__), op_flops_path), "rb") as fp: + self._op_flops_dict = pickle.load(fp) + + self.stage_blocks = [4, 4, 8, 4] + self.stage_channels = [64, 160, 320, 640] + self._parsed_flops = dict() + self._input_size = input_size + self._feature_map_size = input_size + self._first_conv_channels = first_conv_channels + self._last_conv_channels = last_conv_channels + self._n_classes = 
def _make_blocks(self, blocks, in_channels, channels):
    """Build the searchable blocks of one stage and record their FLOPs.

    Each block is a ``LayerChoice`` over four candidates: ``ShuffleNetBlock``
    with kernel size 3, 5 and 7, and ``ShuffleXceptionBlock``. The first block
    of the stage uses stride 2 and ``in_channels`` inputs; every later block
    uses stride 1 and ``channels`` inputs.

    Side effects: ``self._parsed_flops`` gains one entry per block (a list of
    FLOPs, one per name in ``self.block_keys``), ``self._feature_map_size`` is
    halved after every stride-2 block, and the keys of ``self._parsed_flops``
    are renumbered at the end (see the note below).

    Parameters
    ----------
    blocks : int
        Number of choice blocks to create for this stage.
    in_channels : int
        Input channels of the first block.
    channels : int
        Output channels of every block in the stage.

    Returns
    -------
    list
        The created ``LayerChoice`` modules, in order.
    """
    result = []
    for i in range(blocks):
        # Only the first block of a stage downsamples and changes the width.
        stride = 2 if i == 0 else 1
        inp = in_channels if i == 0 else channels
        oup = channels

        base_mid_channels = channels // 2
        mid_channels = int(base_mid_channels)  # prepare for scale
        choice_block = LayerChoice([
            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
            ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
        ])
        result.append(choice_block)

        # find the corresponding flops
        # The lookup key uses the feature-map size *before* this block's own
        # downsampling is applied.
        flop_key = (inp, oup, mid_channels, self._feature_map_size, self._feature_map_size, stride)
        self._parsed_flops[choice_block.key] = [
            self._op_flops_dict["{}_stride_{}".format(k, stride)][flop_key] for k in self.block_keys
        ]

        if stride == 2:
            self._feature_map_size //= 2

    # ##### mended by han ###################
    # Every choice_block created through mutables.LayerChoice gets a key whose
    # number keeps auto-incrementing by 1, so the key numbers stored in
    # self._parsed_flops can exceed 20 -- and such keys do not exist.
    # Because all algorithms share the same mutable implementation, the helpers
    #   global_mutable_counting()
    #   _reset_global_mutable_counting()
    # are deliberately neither called nor modified there; the keys of
    # self._parsed_flops are renamed here instead.
    # NOTE(review): the slice width 11 and the modulus 20 are hard-coded; they
    # assume the "LayerChoice" prefix and 20 choice blocks in total -- confirm
    # before changing the stage layout.
    _d = dict()
    for key, value in self._parsed_flops.items():
        _head = key[:11]  # LayerChoice
        _index = int(key[11:]) % 20  # modulo 20: there are 20 choice blocks, keep the number within range
        if _index == 0:
            _index = 20  # a number that is 0 modulo 20 is really block 20
        _d.update({_head + str(_index): value})

    self._parsed_flops = _d
    # #######################################
    return result
+ # https://github.com/megvii-model/SinglePathOneShot/blob/36eed6cf083497ffa9cfe7b8da25bb0b6ba5a452/src/Supernet/flops.py#L313 + rest_flops = self._op_flops_dict["rest_operation"][(self.stage_channels[-1], self._n_classes, + self._feature_map_size, self._feature_map_size, 1)] + total_flops = conv1_flops + rest_flops + for k, m in candidate.items(): + parsed_flops_dict = self._parsed_flops[k] + if isinstance(m, dict): # to be compatible with classical nas format + total_flops += parsed_flops_dict[m["_idx"]] + else: + total_flops += parsed_flops_dict[torch.max(m, 0)[1]] + return total_flops + + def _initialize_weights(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + if 'first' in name: + nn.init.normal_(m.weight, 0, 0.01) + else: + nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0001) + nn.init.constant_(m.running_mean, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0001) + nn.init.constant_(m.running_mean, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"): + checkpoint = torch.load(filepath, map_location=torch.device("cpu")) + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + result = dict() + for k, v in checkpoint.items(): + if k.startswith("module."): + k = k[len("module."):] + result[k] = v + return result + + +if __name__ == "__main__": + model = ShuffleNetV2OneShot() \ No newline at end of file diff --git a/dubhe-tadl/spos/readme_spos.md b/dubhe-tadl/spos/readme_spos.md new file mode 100644 index 0000000..d60e68e --- /dev/null +++ 
b/dubhe-tadl/spos/readme_spos.md @@ -0,0 +1,68 @@ +# Single Path One-Shot(SPOS) +## **简介** +该方法由[Single Path One-Shot Neural Architecture Search with Uniform Sampling](https://arxiv.org/abs/1904.00420) +中提出,主体思想可以分为两个部分,分别是Single Path和One-shot。其中One-Shot指,前期训练一个超网络, +后期对超网络不断进行采样或剪枝等等的方法来获得最终的子网络。而Single Path指,对于训练好的超网络,每一个模型都是超网络的一条路径。 +该算法整体来看即:将网络的层级结构视为一条路径,路径的节点即每个神经层,每个节点有多种选择(多种神经层),对每个节点进行采样得到一个确定的神经层, +并连接每个节点成为一个路径,该路径即最终采样得到的子网络。 + +本实例参照microsoft nni中的spos repo实现了spos的超网训练、子网络的进化搜索、最终选取网络的重训练。 + +## 使用介绍 + +- 模型的训练用到了NVIDIA dali工具,需要提前[安装](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) +- 模型的训练使用imagenet数据集,需要提前准备 +- 模型的flops计算需要用到一个flops查找表,可以在[megvii](https://onedrive.live.com/?authkey=%21ADesvSdfsq%5FcN48&id=E7CA2ABE6D98E66F%21106&cid=E7CA2ABE6D98E66F) +下载。同时这里还可以下载到官方提供的supernet模型,以及最终重训练的模型等等。 + +### **目录结构** +可以将imagenet数据放在```./data```目录下,标准的数据处理方式可以参考[这里](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4) + +imagenet文件准备好之后,训练集和测试集应分别包含1000个子文件夹。 + +将文件准备齐全之后,目录结构应类似如下: +``` +spos +├── architecture_final.json +├── blocks.py +├── config_search.yml +├── data +│ ├── imagenet +│ │ ├── train +│ │ └── val +│ └── op_flops_dict.pkl +├── dataloader.py +├── network.py +├── readme.md +├── scratch.py +├── supernet.py +├── tester.py +├── evolution_tuner.py +└── utils.py +``` + +### **超网络的训练** +```python supernet.py``` +- 如果不需要训练整个超网络,可以使用上述地址中下载的supernet网络,并将其放在```./data```目录下 +- 训练完成之后,checkpoint会导出在```./checkpoints```路径下 +- 为了和[官方repo](https://github.com/megvii-model/SinglePathOneShot) 保持一致,数据的通道使用BGR模式,同时数据的输入范围保持在[0,255]. 
+ +### **子网络的进化搜索** +首先准备搜索空间 + +```python tester.py --mode gen``` + +然后进行基于进化算法的搜索 + +```python search.py``` +- 每次进化都会选出若干最优,其数目定义在dali_loader.py中,最终的准确率保存在```./acc```路径下 +- 进化的模型结构(仅包含结构的json文件)保存在```./checkpoints```路径下 +- 模型结构的映射关系保存在```./id2cand```路径下 + +### **最终模型的重训练** +```python scratch.py``` + + + + +today \ No newline at end of file diff --git a/dubhe-tadl/spos/runs/events.out.tfevents.1614222519.qjy-ai10.44179.0 b/dubhe-tadl/spos/runs/events.out.tfevents.1614222519.qjy-ai10.44179.0 new file mode 100644 index 0000000..9d6ce8f Binary files /dev/null and b/dubhe-tadl/spos/runs/events.out.tfevents.1614222519.qjy-ai10.44179.0 differ diff --git a/dubhe-tadl/spos/scratch.py b/dubhe-tadl/spos/scratch.py new file mode 100644 index 0000000..d914064 --- /dev/null +++ b/dubhe-tadl/spos/scratch.py @@ -0,0 +1,158 @@ +import os +import argparse +import logging +import random + +import numpy as np +import torch +import torch.nn as nn +from dataloader import get_imagenet_iter_dali +from pytorch.fixed import apply_fixed_architecture +from pytorch.utils import AverageMeterGroup +from torch.utils.tensorboard import SummaryWriter +# import torch.distributed as dist + +from network import ShuffleNetV2OneShot +from utils import CrossEntropyLabelSmooth, accuracy + +# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +logger = logging.getLogger("nni.spos.scratch") + + +def train(epoch, model, criterion, optimizer, loader, writer, args): + model.train() + meters = AverageMeterGroup() + cur_lr = optimizer.param_groups[0]["lr"] + + for step, (x, y) in enumerate(loader): + cur_step = len(loader) * epoch + step + optimizer.zero_grad() + logits = model(x) + loss = criterion(logits, y) + loss.backward() + optimizer.step() + + metrics = accuracy(logits, y) + metrics["loss"] = loss.item() + meters.update(metrics) + + writer.add_scalar("lr", cur_lr, global_step=cur_step) + writer.add_scalar("loss/train", loss.item(), global_step=cur_step) + writer.add_scalar("acc1/train", metrics["acc1"], 
def validate(epoch, model, criterion, loader, writer, args):
    """Evaluate ``model`` on ``loader`` for one pass and log the averages.

    The model is put in eval mode and run under ``torch.no_grad()``. Loss and
    the metrics returned by ``accuracy`` are accumulated in an
    ``AverageMeterGroup``; the averaged ``loss``, ``acc1`` and ``acc5`` are
    written to ``writer`` with ``epoch`` as the global step and logged.

    Parameters
    ----------
    epoch : int
        Zero-based epoch index (logged as ``epoch + 1``).
    model : callable
        Network under evaluation; called as ``model(x)``.
    criterion : callable
        Loss function, called as ``criterion(logits, y)``.
    loader : iterable
        Yields ``(x, y)`` batches and supports ``len()``.
    writer : SummaryWriter
        TensorBoard writer receiving the scalars.
    args : argparse.Namespace
        Needs ``log_frequency`` and ``epochs``.
    """
    model.eval()
    meters = AverageMeterGroup()
    num_steps = len(loader)
    with torch.no_grad():
        for step, (x, y) in enumerate(loader):
            logits = model(x)
            loss = criterion(logits, y)
            metrics = accuracy(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if step % args.log_frequency == 0 or step + 1 == num_steps:
                logger.info("Epoch [%d/%d] Validation Step [%d/%d] %s", epoch + 1,
                            args.epochs, step + 1, num_steps, meters)

            # Stop after exactly one pass. The previous guard
            # (`step > len(loader)`) could not fire within a single pass, so a
            # loader that does not stop by itself ran len + 2 batches. For an
            # ordinary finite loader this break coincides with the natural end.
            if step + 1 >= num_steps:
                break

    writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch)
    writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch)
    writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch)

    logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1, meters.acc1.avg, meters.acc5.avg)
"architecture_final.json" + parser.add_argument("--workers", type=int, default=4) + parser.add_argument("--batch-size", type=int, default=1024) + parser.add_argument("--epochs", type=int, default=240) + parser.add_argument("--learning-rate", type=float, default=0.5) + parser.add_argument("--momentum", type=float, default=0.9) + parser.add_argument("--weight-decay", type=float, default=4E-5) + parser.add_argument("--label-smooth", type=float, default=0.1) + parser.add_argument("--log-frequency", type=int, default=10) + parser.add_argument("--lr-decay", type=str, default="linear") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--spos-preprocessing", default=False, action="store_true") + parser.add_argument("--label-smoothing", type=float, default=0.1) + parser.add_argument("--local_rank", default=[0,1,2,3]) + + args = parser.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + torch.backends.cudnn.deterministic = True + + model = ShuffleNetV2OneShot(affine=True) + model.cuda("cuda:0") + apply_fixed_architecture(model, args.architecture) + + # state_dict是否发生变化 + state_dict = model.state_dict() + + # todo DDP并行的一些设置 + # dist.init_process_group(backend = "nccl") + + if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu + model = nn.DataParallel(model, + device_ids=list(range(0, torch.cuda.device_count() - 1))) # todo # device_ids=list(range(0, torch.cuda.device_count() - 1)) + # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.local_rank) + criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing) + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, + momentum=args.momentum, weight_decay=args.weight_decay) + if args.lr_decay == "linear": + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, + lambda step: (1.0 - step / args.epochs) + if step <= args.epochs else 0, 
if __name__ == "__main__":
    # Entry point of supernet.py: parse options, build the SPOS supernet and
    # train it with SPOSSupernetTrainer.
    parser = argparse.ArgumentParser("SPOS Supernet Training")

    # The dataset path is machine specific: the home partition is small, so the
    # data is stored under /mnt/local (the previous default was "./data/imagenet").
    parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/imagenet")
    parser.add_argument("--load-checkpoint", action="store_true", default=False)
    parser.add_argument("--spos-preprocessing", action="store_true", default=False,
                        help="When true, image values will range from 0 to 255 and use BGR "
                             "(as in original repo).")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--batch-size", type=int, default=512)  # originally 768
    parser.add_argument("--epochs", type=int, default=120)  # originally 120
    parser.add_argument("--learning-rate", type=float, default=0.5)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=4E-5)
    # NOTE: --label-smooth is accepted for compatibility but never read below;
    # the criterion uses --label-smoothing.
    parser.add_argument("--label-smooth", type=float, default=0.1)
    parser.add_argument("--log-frequency", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--label-smoothing", type=float, default=0.1)

    args = parser.parse_args()

    # Seed every RNG in use so runs are repeatable.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    model = ShuffleNetV2OneShot()
    # Bind the FLOPs estimator before the model may be wrapped in DataParallel.
    flops_func = model.get_candidate_flops
    if args.load_checkpoint:
        if not args.spos_preprocessing:
            logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.")
        model.load_state_dict(load_and_parse_state_dict())
    model.cuda()
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:  # exclude last gpu, saving for data preprocessing on gpu
        model = nn.DataParallel(model, device_ids=list(range(0, n_gpus - 1)))
    mutator = SPOSSupernetTrainingMutator(model, flops_func=flops_func,
                                          flops_lb=290E6, flops_ub=360E6)
    criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    # Linear decay of the learning-rate multiplier from 1 to 0 over args.epochs.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lambda step: (1.0 - step / args.epochs)
                                                  if step <= args.epochs else 0,
                                                  last_epoch=-1)

    # Data preprocessing runs on the GPU left out of DataParallel above, i.e.
    # the last visible one. This was hard-coded to 3, which is the same value
    # with four visible GPUs but fails when fewer are visible.
    device_id = max(n_gpus - 1, 0)
    train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers,
                                          spos_preprocessing=args.spos_preprocessing, device_id=device_id)
    valid_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers,
                                          spos_preprocessing=args.spos_preprocessing, device_id=device_id)
    trainer = SPOSSupernetTrainer(model, criterion, accuracy, optimizer,
                                  args.epochs, train_loader, valid_loader,
                                  mutator=mutator, batch_size=args.batch_size,
                                  log_frequency=args.log_frequency, workers=args.workers,
                                  callbacks=[LRSchedulerCallback(scheduler),
                                             ModelCheckpoint("./checkpoints")])
    trainer.train()
if __name__ == "__main__":
    # Entry point of tester.py: load supernet weights, apply one architecture,
    # recalibrate batch-norm statistics, measure top-1 accuracy on the
    # validation set and write it to ./acc/<architecture file name>.

    def _parse_bool(text):
        """Turn a command-line token such as "true"/"0" into a bool."""
        token = text.strip().lower()
        if token in ("1", "true", "t", "yes", "y", "on"):
            return True
        if token in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got '%s'" % text)

    parser = argparse.ArgumentParser("SPOS Candidate Tester")
    parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet")  # ./data/imagenet
    parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar")  # ./data/checkpoint-150000.pth.tar
    # Defaults to True. Without a `type`, any value given on the command line
    # was kept as a non-empty (truthy) string, so the option could never be
    # switched off; values are now parsed as booleans.
    parser.add_argument("--spos-preprocessing", default=True, type=_parse_bool,
                        help="When true, image values will range from 0 to 255 and use BGR "
                             "(as in original repo).")  # , action="store_true"
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--workers", type=int, default=6)  # number of worker threads
    parser.add_argument("--train-batch-size", type=int, default=128)
    parser.add_argument("--train-iters", type=int, default=200)
    parser.add_argument("--test-batch-size", type=int, default=512)  # 512 in nni, 200 in the official repo
    parser.add_argument("--log-frequency", type=int, default=10)
    parser.add_argument("--architecture", type=str, default="./architecture_final.json", help="load the file to retrain or eval")
    parser.add_argument("--mode", type=str, default="gen", help="there are two modes here: gen mode for generating architecture, and evl mode for evaluation model")

    args = parser.parse_args()

    # use a fixed set of image will improve the performance
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    assert torch.cuda.is_available()

    model = ShuffleNetV2OneShot()
    criterion = CrossEntropyLabelSmooth(1000, 0.1)

    if args.mode == "gen":
        get_and_apply_next_architecture(model)
        model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint))

    else:  # evaluate the model
        print("## test&retrain -- load model ## begin to load model")
        model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint))
        print("## test&retrain -- load model ## model loaded")

        print("## test&retrain -- apply architecture ## begin to apply architecture to model")
        apply_fixed_architecture(model, args.architecture)
        print("## test&retrain -- apply architecture ## architecture applied")

    model.cuda(0)
    print("## test&retrain -- load train data ## begin to load train data")
    train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.train_batch_size, args.workers,
                                          spos_preprocessing=args.spos_preprocessing,
                                          seed=args.seed, device_id=0)
    print("## test&retrain -- load train data ## train data loaded")

    print("## test&retrain -- load test data ## begin to load test data")
    val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.test_batch_size, args.workers,
                                        spos_preprocessing=args.spos_preprocessing, shuffle=True,
                                        seed=args.seed, device_id=0)
    print("## test&retrain -- load test date ## test data loaded")

    # retrain_bn pulls batches with next(), so the training loader must never
    # run dry.
    train_loader = cycle(train_loader)
    acc = evaluate_acc(model, criterion, args, train_loader, val_loader)

    # Write the final accuracy of the model to a file named after the
    # architecture file. The name used to be taken as the last 12 characters of
    # the path, which is only correct for names like "000_001.json"; the real
    # base name is used instead.
    os.makedirs("./acc", exist_ok=True)
    with open("./acc/{}".format(os.path.basename(args.architecture)), "w") as f:
        # {filename1: acc,
        #  filename2: acc,
        #  000_000.json: acc,
        #  000_001.json: acc,
        #  ......
        # }
        json.dump({args.architecture: acc}, f)
def accuracy(output, target, topk=(1, 5)):
    """Return top-k accuracy for every ``k`` in ``topk``.

    Parameters
    ----------
    output : torch.Tensor
        Scores; ranked along dimension 1.
    target : torch.Tensor
        Either class indices, or a tensor with more than one dimension, which
        is reduced to indices with ``max(1)``.
    topk : tuple of int
        The ``k`` values to evaluate.

    Returns
    -------
    dict
        Maps ``"acc<k>"`` to the fraction of samples (a Python float) whose
        target is among the ``k`` highest-scoring entries.
    """
    n_samples = target.size(0)

    # Soft / one-hot targets are collapsed to class indices first.
    if target.ndimension() > 1:
        target = target.max(1)[1]

    # Indices of the max(topk) best scores per sample, transposed so that
    # row r holds the r-th best guess of every sample.
    ranked = output.topk(max(topk), 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    # reshape (not view) is required here: the slice is not contiguous.
    return {
        "acc{}".format(k): hits[:k].reshape(-1).float().sum(0).mul_(1.0 / n_samples).item()
        for k in topk
    }
class PTBTree:
    """Node of a tree parsed from bracketed text such as ``(3 (2 a) (4 b))``.

    Presumably "PTB" stands for Penn Treebank (the word mapping below uses the
    ``-LCB-``/``-RCB-`` bracket tokens) -- confirm against the data source.
    """

    # Raw tokens that are replaced by an escaped form in `standardise_node`.
    WORD_TO_WORD_MAPPING = {
        "{": "-LCB-",
        "}": "-RCB-"
    }

    def __init__(self):
        """Create an empty node; `set_by_text` fills it in."""
        self.subtrees = []  # child nodes, in textual order
        self.word = None  # token text; set only on nodes without children
        self.label = ""  # text between "(" and the first following space
        self.parent = None  # enclosing node, None for the root
        self.span = (-1, -1)  # (left, right) leaf positions covered, right exclusive
        self.word_vector = None  # HOS, store dx1 RNN word vector
        self.prediction = None  # HOS, store Kx1 prediction vector

    def is_leaf(self):
        """Return True when this node has no children."""
        return len(self.subtrees) == 0

    def set_by_text(self, text, pos=0, left=0):
        """Populate this node (recursively) from bracketed ``text``.

        Parameters
        ----------
        text : str
            The complete bracketed string.
        pos : int
            Index of this node's opening ``(`` in ``text``; scanning starts at
            ``pos + 1``.
        left : int
            Leaf position at which this node's span begins.
        """
        depth = 0
        right = left
        for i in range(pos + 1, len(text)):
            char = text[i]
            # update the depth
            if char == "(":
                depth += 1
                if depth == 1:
                    # A direct child starts here: parse it recursively and
                    # extend this node's span to the child's right edge.
                    subtree = PTBTree()
                    subtree.parent = self
                    subtree.set_by_text(text, i, right)
                    right = subtree.span[1]
                    self.span = (left, right)
                    self.subtrees.append(subtree)
            elif char == ")":
                depth -= 1
                if len(self.subtrees) == 0:
                    # No children were found, so this is a leaf: the word is
                    # the text between the last space before ")" and ")".
                    pos = i
                    for j in range(i, 0, -1):
                        if text[j] == " ":
                            pos = j
                            break
                    self.word = text[pos + 1:i]
                    self.span = (left, left + 1)

            # we've reached the end of the category that is the root of this subtree
            if depth == 0 and char == " " and self.label == "":
                self.label = text[pos + 1:i]
            # we've reached the end of the scope for this bracket
            if depth < 0:
                break

        # Fix some issues with variation in output, and one error in the treebank
        # for a word with a punctuation POS
        self.standardise_node()

    def standardise_node(self):
        """Replace ``self.word`` by its mapped form if it is in the mapping."""
        if self.word in self.WORD_TO_WORD_MAPPING:
            self.word = self.WORD_TO_WORD_MAPPING[self.word]

    def __repr__(self, single_line=True, depth=0):
        """Render the subtree back to bracketed text.

        With ``single_line=False`` every non-root node starts on a new line,
        indented by ``depth`` tab characters.
        """
        ans = ""
        if not single_line and depth > 0:
            ans = "\n" + depth * "\t"
        ans += "(" + self.label
        if self.word is not None:
            ans += " " + self.word
        for subtree in self.subtrees:
            if single_line:
                ans += " "
            ans += subtree.__repr__(single_line, depth + 1)
        ans += ")"
        return ans
def get_word_id_dict(word_num_dict, word_id_dict, min_count):
    """Extend ``word_id_dict`` with ids for sufficiently frequent words.

    Words are visited in sorted order. A word is added when its count in
    ``word_num_dict`` is at least ``min_count`` and it has no id yet; the new
    id is the size of ``word_id_dict`` at that moment.

    Parameters
    ----------
    word_num_dict : dict
        Maps word -> occurrence count.
    word_id_dict : dict
        Maps word -> id; updated in place.
    min_count : int
        Minimum count for a word to be given an id.

    Returns
    -------
    dict
        The same ``word_id_dict`` object that was passed in.
    """
    for word in sorted(word_num_dict):
        # Skip rare words and words that already own an id.
        if word_num_dict[word] < min_count or word in word_id_dict:
            continue
        word_id_dict[word] = len(word_id_dict)
    return word_id_dict
def load_glove_model(filename, embed_dim):
    """Load word vectors from a whitespace-separated text file.

    Each line holds a word followed by ``embed_dim`` numbers. The line is split
    from the right, so the word itself may contain spaces. The parsed result is
    pickled next to the source file as ``filename + ".cache"`` and that cache
    is returned directly on later calls.

    Parameters
    ----------
    filename : str
        Path of the embedding text file.
    embed_dim : int
        Number of vector components expected on every line.

    Returns
    -------
    dict
        ``{"mapping": {word: row_index}, "pool": float32 array of shape
        (num_words, embed_dim)}``.

    Raises
    ------
    ValueError
        If a line does not contain exactly ``embed_dim`` vector components.
        This was an ``assert`` against the constant 300, which rejected every
        other dimension and disappeared under ``python -O``.
    """
    cache_path = filename + ".cache"
    if os.path.exists(cache_path):
        logger.info("Found cache. Loading...")
        # NOTE: pickle executes code on load -- only safe because the cache is
        # a local file written by this function. The cache is keyed by file
        # name only, so it must be deleted when embed_dim changes.
        with open(cache_path, "rb") as fp:
            return pickle.load(fp)

    embedding = {"mapping": dict(), "pool": []}
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(filename, encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.rstrip("\n")
            vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
            if len(vec) != embed_dim:
                raise ValueError("Unexpected line: '%s'" % line)
            embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
            embedding["mapping"][vocab_word] = i
    embedding["pool"] = np.stack(embedding["pool"])

    with open(cache_path, "wb") as fp:
        pickle.dump(embedding, fp)
    return embedding
word_num_dict) + logger.info("Finish load valid+test words: %d.", len(word_num_dict)) + word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count) + logger.info("After trim vocab length: %d.", len(word_id_dict)) + + logger.info("Loading embedding...") + embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict) + logger.info("Finish initialize word embedding.") + + dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length) + logger.info("Loaded %d training samples.", len(dataset_train)) + dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length) + logger.info("Loaded %d validation samples.", len(dataset_valid)) + dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length) + logger.info("Loaded %d test samples.", len(dataset_test)) + + return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding) diff --git a/dubhe-tadl/textnas/model.py b/dubhe-tadl/textnas/model.py new file mode 100644 index 0000000..c4ea695 --- /dev/null +++ b/dubhe-tadl/textnas/model.py @@ -0,0 +1,124 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
# NOTE: this span was a whitespace-collapsed diff hunk; it is reflowed here as
# plain Python (diff "+" markers dropped, embedded diff headers kept as comments).

import numpy as np
import torch
import torch.nn as nn
from pytorch import mutables

from ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm
from utils import GlobalMaxPool, GlobalAvgPool


class Layer(mutables.MutableScope):
    """One searchable layer: input choice -> operator choice -> optional skip sum -> batch norm.

    Parameters
    ----------
    key : str
        Key of this mutable scope.
    prev_keys : list of str
        Keys of the layers built before this one; empty for the first layer.
    hidden_units : int
        Channel count used by every candidate operator.
    choose_from_k : int
        The main input is chosen among the last ``choose_from_k`` previous keys.
    cnn_keep_prob, lstm_keep_prob, att_keep_prob : float
        Keep probabilities forwarded to the conv, RNN and attention candidates.
    att_mask : bool
        Forwarded to ``Attention`` as its ``is_mask`` argument.
    """

    def __init__(self,
                 key,
                 prev_keys,
                 hidden_units,
                 choose_from_k,
                 cnn_keep_prob,
                 lstm_keep_prob,
                 att_keep_prob,
                 att_mask):
        super(Layer, self).__init__(key)

        def conv_shortcut(kernel_size):
            return ConvBN(kernel_size, hidden_units, hidden_units, cnn_keep_prob, False, True)

        self.n_candidates = len(prev_keys)
        if self.n_candidates:
            self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1)
        else:
            # first layer, skip input choice
            self.prec = None
        self.op = mutables.LayerChoice([
            conv_shortcut(1),
            conv_shortcut(3),
            conv_shortcut(5),
            conv_shortcut(7),
            AvgPool(3, False, True),
            MaxPool(3, False, True),
            RNN(hidden_units, lstm_keep_prob),
            Attention(hidden_units, 4, att_keep_prob, att_mask)
        ])
        if self.n_candidates:
            self.skipconnect = mutables.InputChoice(choose_from=prev_keys)
        else:
            self.skipconnect = None
        self.bn = BatchNorm(hidden_units, False, True)

    def forward(self, last_layer, prev_layers, mask):
        """Apply the layer.

        ``last_layer`` is passed separately so that layer 0, whose
        ``prev_layers`` is empty, still has an input.
        """
        if self.prec is None:
            prec = last_layer
        else:
            prec = self.prec(prev_layers[-self.prec.n_candidates:])  # skip first
        out = self.op(prec, mask)
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:])
            if connection is not None:
                # Out-of-place add: the op output may be a tensor autograd
                # still needs (e.g. a ReLU output passed through a p=0
                # dropout), and `out += connection` would modify it in place.
                out = out + connection
        out = self.bn(out, mask)
        return out


class Model(nn.Module):
    """TextNAS backbone: embedding -> 1x1 ConvBN -> stacked ``Layer`` -> combine -> pool -> linear.

    Parameters
    ----------
    embedding : torch.Tensor
        Pretrained embedding matrix handed to ``nn.Embedding.from_pretrained``
        (not frozen).
    hidden_units : int
        Channel count of every layer.
    num_layers : int
        Number of searchable ``Layer`` modules.
    num_classes : int
        Size of the output logits.
    choose_from_k : int
        Forwarded to every ``Layer``.
    lstm_keep_prob, cnn_keep_prob, att_keep_prob : float
        Keep probabilities forwarded to the candidate operators.
    att_mask : bool
        Forwarded to every ``Layer``.
    embed_keep_prob, final_output_keep_prob : float
        Keep probabilities of the embedding and output dropouts.
    global_pool : str
        ``"max"`` or ``"avg"``.
    """

    def __init__(self, embedding, hidden_units=256, num_layers=24, num_classes=5, choose_from_k=5,
                 lstm_keep_prob=0.5, cnn_keep_prob=0.5, att_keep_prob=0.5, att_mask=True,
                 embed_keep_prob=0.5, final_output_keep_prob=1.0, global_pool="avg"):
        super(Model, self).__init__()

        # load word embedding
        self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False)
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.num_classes = num_classes
        # first layer: project the embedding dimension to hidden_units
        self.init_conv = ConvBN(1, self.embedding.embedding_dim, hidden_units, cnn_keep_prob, False, True)

        self.layers = nn.ModuleList()
        candidate_keys_pool = []  # ['layer_0', 'layer_1']
        for layer_id in range(self.num_layers):
            k = "layer_{}".format(layer_id)
            # Layer only reads the list while constructing, so appending the
            # new key afterwards does not affect the layer just built.
            self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k,
                                     cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask))
            candidate_keys_pool.append(k)

        self.linear_combine = LinearCombine(self.num_layers)
        self.linear_out = nn.Linear(self.hidden_units, self.num_classes)

        self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob)
        self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob)

        assert global_pool in ["max", "avg"]
        if global_pool == "max":
            self.global_pool = GlobalMaxPool()
        elif global_pool == "avg":
            self.global_pool = GlobalAvgPool()

    def forward(self, inputs):
        """Return class logits for ``inputs = (sent_ids, mask)``."""
        sent_ids, mask = inputs
        seq = self.embedding(sent_ids.long())
        seq = self.embed_dropout(seq)

        seq = torch.transpose(seq, 1, 2)  # from (N, L, C) -> (N, C, L)
        # from (batch_size, seq_len, feat_size) -> (batch_size, feat_size, seq_len)

        x = self.init_conv(seq, mask)
        prev_layers = []

        for layer in self.layers:
            x = layer(x, prev_layers, mask)
            prev_layers.append(x)

        x = self.linear_combine(torch.stack(prev_layers))
        x = self.global_pool(x, mask)
        x = self.output_dropout(x)
        x = self.linear_out(x)
        return x


# diff --git a/dubhe-tadl/textnas/ops.py b/dubhe-tadl/textnas/ops.py
# new file mode 100644
# index 0000000..25e1ef1
# --- /dev/null
# +++ b/dubhe-tadl/textnas/ops.py
# @@ -0,0 +1,228 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# NOTE: this span was a whitespace-collapsed diff hunk; it is reflowed here as
# plain Python (diff "+" markers dropped, embedded diff headers kept as comments).

import torch
import torch.nn.functional as F
from torch import nn

from utils import get_length, INF


class Mask(nn.Module):
    """Zero out the positions of ``seq`` whose mask entry is not 1."""

    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L)
        seq_mask = torch.unsqueeze(mask, 2)
        seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2)
        return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq))

    def __str__(self):
        return 'Mask'


class BatchNorm(nn.Module):
    """``nn.BatchNorm1d`` with optional masking before and/or after.

    ``decay`` is converted to the PyTorch ``momentum`` as ``1.0 - decay``.
    """

    def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
        super(BatchNorm, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.bn(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'BatchNorm'


class ConvBN(nn.Module):
    """Length-preserving 1-D convolution, then optional BatchNorm and ReLU, then dropout."""

    def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
                 pre_mask, post_mask, with_bn=True, with_relu=True):
        super(ConvBN, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.kernal_size = kernal_size
        self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True,
                              padding=(kernal_size - 1) // 2)
        self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))

        if with_bn:
            # Pre-mask inside BatchNorm only when the conv output was not
            # already masked by post_mask.
            self.bn = BatchNorm(out_channels, not post_mask, True)

        if with_relu:
            self.relu = nn.ReLU()

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.conv(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        if self.with_bn:
            seq = self.bn(seq, mask)
        if self.with_relu:
            seq = self.relu(seq)
        seq = self.dropout(seq)
        return seq

    def __str__(self):
        return 'ConvBN_{}'.format(self.kernal_size)


class AvgPool(nn.Module):
    """Stride-1 average pooling with optional masking before and/or after."""

    def __init__(self, kernal_size, pre_mask, post_mask):
        super(AvgPool, self).__init__()
        self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()
        self.kernal_size = kernal_size

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.avg_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'AvgPool{}'.format(self.kernal_size)


class MaxPool(nn.Module):
    """Stride-1 max pooling with optional masking before and/or after."""

    def __init__(self, kernal_size, pre_mask, post_mask):
        super(MaxPool, self).__init__()
        self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()
        self.kernel_size = kernal_size

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.max_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'MaxPool{}'.format(self.kernel_size)


class Attention(nn.Module):
    """Multi-head self-attention over the sequence axis, followed by BatchNorm.

    NOTE(review): heads are formed with ``torch.split(..., in_c // num_heads)``,
    which presumably expects ``num_units`` to be divisible by ``num_heads``.
    """

    def __init__(self, num_units, num_heads, keep_prob, is_mask):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.keep_prob = keep_prob

        self.linear_q = nn.Linear(num_units, num_units)
        self.linear_k = nn.Linear(num_units, num_units)
        self.linear_v = nn.Linear(num_units, num_units)

        self.bn = BatchNorm(num_units, True, is_mask)
        self.dropout = nn.Dropout(p=1 - self.keep_prob)

    def forward(self, seq, mask):
        in_c = seq.size()[1]
        seq = torch.transpose(seq, 1, 2)  # (N, L, C)
        queries = seq
        keys = seq
        num_heads = self.num_heads

        # T_q = T_k = L
        Q = F.relu(self.linear_q(seq))  # (N, T_q, C)
        K = F.relu(self.linear_k(seq))  # (N, T_k, C)
        V = F.relu(self.linear_v(seq))  # (N, T_k, C)

        # Split and concat
        Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = torch.matmul(Q_, K_.transpose(1, 2))  # (h*N, T_q, T_k)
        # Scale
        outputs = outputs / (K_.size()[-1] ** 0.5)
        # Key Masking
        key_masks = mask.repeat(num_heads, 1)  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k)
        key_masks = key_masks.repeat(1, queries.size()[1], 1)  # (h*N, T_q, T_k)

        paddings = torch.ones_like(outputs) * (-INF)  # extremely small value
        outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)

        query_masks = mask.repeat(num_heads, 1)  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1)
        query_masks = query_masks.repeat(1, 1, keys.size()[1]).float()  # (h*N, T_q, T_k)

        att_scores = F.softmax(outputs, dim=-1) * query_masks  # (h*N, T_q, T_k)
        att_scores = self.dropout(att_scores)

        # Weighted sum
        x_outputs = torch.matmul(att_scores, V_)  # (h*N, T_q, C/h)
        # Restore shape
        x_outputs = torch.cat(
            torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
            dim=2)  # (N, T_q, C)

        x = torch.transpose(x_outputs, 1, 2)  # (N, C, L)
        x = self.bn(x, mask)

        return x

    def __str__(self):
        return 'Attention'


class RNN(nn.Module):
    """Bidirectional GRU whose two directions are summed, followed by dropout."""

    def __init__(self, hidden_size, output_keep_prob):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_keep_prob = output_keep_prob

        self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))

    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L)
        max_len = seq.size()[2]
        length = get_length(mask)
        seq = torch.transpose(seq, 1, 2)  # to (N, L, C)
        # NOTE(review): pack_padded_sequence presumably requires every length
        # to be positive, so an all-zero mask row would fail here - confirm
        # that the data pipeline never produces one.
        packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True,
                                                       enforce_sorted=False)
        outputs, _ = self.bid_rnn(packed_seq)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
                                                   total_length=max_len)[0]
        outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2)  # (N, L, C)
        outputs = self.out_dropout(outputs)  # output dropout
        return torch.transpose(outputs, 1, 2)  # back to: (N, C, L)

    def __str__(self):
        return 'RNN'


class LinearCombine(nn.Module):
    """Softmax-weighted sum over the leading (layer) axis of a stacked tensor."""

    def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
        super(LinearCombine, self).__init__()
        self.input_aware = input_aware
        self.word_level = word_level

        if input_aware:
            raise NotImplementedError("Input aware is not supported.")
        self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num),
                              requires_grad=trainable)

    def forward(self, seq):
        nw = F.softmax(self.w, dim=0)
        seq = torch.mul(seq, nw)
        seq = torch.sum(seq, dim=0)
        return seq


# diff --git a/dubhe-tadl/textnas/retrain.py b/dubhe-tadl/textnas/retrain.py
# new file mode 100644
# index 0000000..152889c
# --- /dev/null
# +++ b/dubhe-tadl/textnas/retrain.py
# @@ -0,0 +1,536 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
import logging
import pickle
import shutil
import random
import math

import time
import datetime
import argparse
import distutils.util

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as Func

from model import Model
from pytorch.fixed import apply_fixed_architecture
from dataloader import read_data_sst


logger = logging.getLogger("nni.textnas")


def parse_args():
    """Parse the command line and store the result in the module-level ``FLAGS``."""
    to_bool = distutils.util.strtobool
    required = object()  # marks an option that is mandatory and has no default

    # (flag name, type, default, help text); order is kept for --help output.
    specs = [
        ("reset_output_dir", to_bool, True, "Whether to clean the output dir if existed."),
        ("child_fixed_arc", str, required, "Architecture json file."),
        ("data_path", str, "data", "Directory containing the dataset and embedding file."),
        ("output_dir", str, "output", "The output directory."),
        ("child_lr_decay_scheme", str, "cosine",
         "Learning rate annealing strategy, only 'cosine' supported."),
        ("batch_size", int, 128, "Number of samples each batch for training."),
        ("eval_batch_size", int, 128, "Number of samples each batch for evaluation."),
        ("class_num", int, 5, "The number of categories."),
        ("global_seed", int, 1234, "Seed for reproduction."),
        ("max_input_length", int, 64, "The maximum length of the sentence."),
        ("num_epochs", int, 10, "The number of training epochs."),
        ("child_num_layers", int, 24, "The layer number of the architecture."),
        ("child_out_filters", int, 256, "The dimension of hidden states."),
        ("child_out_filters_scale", int, 1, "The scale of hidden state dimension."),
        ("child_lr_T_0", int, 10, "The length of one cycle."),
        ("child_lr_T_mul", int, 2, "The multiplication factor per cycle."),
        ("min_count", int, 1, "The threshold to cut off low frequent words."),
        ("train_ratio", float, 1.0, "The sample ratio for the training set."),
        ("valid_ratio", float, 1.0, "The sample ratio for the dev set."),
        ("child_grad_bound", float, 5.0, "The threshold for gradient clipping."),
        ("child_lr", float, 0.02, "The initial learning rate."),
        ("cnn_keep_prob", float, 0.8, "Keep prob for cnn layer."),
        ("final_output_keep_prob", float, 1.0, "Keep prob for the last output layer."),
        ("lstm_out_keep_prob", float, 0.8, "Keep prob for the RNN layer."),
        ("embed_keep_prob", float, 0.8, "Keep prob for the embedding layer."),
        ("attention_keep_prob", float, 0.8, "Keep prob for the self-attention layer."),
        ("child_l2_reg", float, 3e-6, "Weight decay factor."),
        ("child_lr_max", float, 0.002, "The max learning rate."),
        ("child_lr_min", float, 0.001, "The min learning rate."),
        ("child_optim_algo", str, "adam", "Optimization algorithm."),
        ("checkpoint_dir", str, "best_checkpoint", "Path for saved checkpoints."),
        ("output_type", str, "avg", "Opertor type for the time steps reduction."),
        ("multi_path", to_bool, False, "Search for multiple path in the architecture."),
        ("is_binary", to_bool, False, "Binary label for sst dataset."),
        ("is_cuda", to_bool, True, "Specify the device type."),
        ("is_mask", to_bool, True, "Apply mask."),
        ("fixed_seed", to_bool, True, "Fix the seed."),
        ("load_checkpoint", to_bool, False, "Wether to load checkpoint."),
        ("log_every", int, 50, "How many steps to log."),
        ("eval_every_epochs", int, 1, "How many epochs to eval."),
    ]

    parser = argparse.ArgumentParser()
    for flag_name, flag_type, default, help_text in specs:
        options = {"type": flag_type, "help": help_text + " (default: %(default)s)"}
        if default is required:
            options["required"] = True
        else:
            options["default"] = default
        parser.add_argument("--" + flag_name, **options)

    global FLAGS

    FLAGS = parser.parse_args()


def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and, if enabled, all CUDA devices)."""
    logger.info("set random seed for data reading: {}".format(seed))
    # Only affects subprocesses: hash randomization of the running interpreter
    # is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if FLAGS.is_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True


def get_model(embedding, num_layers):
    """Build ``Model`` from ``FLAGS`` and fix its architecture to ``FLAGS.child_fixed_arc``."""
    logger.info("num layers: {0}".format(num_layers))
    assert FLAGS.child_fixed_arc is not None, "Architecture should be provided."

    child_model = Model(
        embedding=embedding,
        hidden_units=FLAGS.child_out_filters_scale * FLAGS.child_out_filters,
        num_layers=num_layers,
        num_classes=FLAGS.class_num,
        choose_from_k=5 if FLAGS.multi_path else 1,
        lstm_keep_prob=FLAGS.lstm_out_keep_prob,
        cnn_keep_prob=FLAGS.cnn_keep_prob,
        att_keep_prob=FLAGS.attention_keep_prob,
        att_mask=FLAGS.is_mask,
        embed_keep_prob=FLAGS.embed_keep_prob,
        final_output_keep_prob=FLAGS.final_output_keep_prob,
        global_pool=FLAGS.output_type)

    apply_fixed_architecture(child_model, FLAGS.child_fixed_arc)
    return child_model


def eval_once(child_model, device, eval_set, criterion, valid_dataloader=None, test_dataloader=None):
    """Evaluate ``child_model`` on the ``"valid"`` or ``"test"`` loader.

    Returns
    -------
    tuple
        ``(accuracy, loss)`` where ``accuracy`` is correct / total samples
        (0 when the loader is empty) and ``loss`` is a 0-d tensor holding the
        mean of the per-batch losses.
    """
    if eval_set == "test":
        assert test_dataloader is not None
        dataloader = test_dataloader
    elif eval_set == "valid":
        assert valid_dataloader is not None
        dataloader = valid_dataloader
    else:
        raise NotImplementedError("Unknown eval_set '{}'".format(eval_set))

    tot_acc = 0
    tot = 0
    losses = []

    with torch.no_grad():  # save memory
        for batch in dataloader:
            (sent_ids, mask), labels = batch

            sent_ids = sent_ids.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = child_model((sent_ids, mask))  # run

            loss = criterion(logits, labels.long())
            loss = loss.mean()
            preds = logits.argmax(dim=1).long()
            acc = torch.eq(preds, labels.long()).long().sum().item()

            # Keep plain floats so no device tensor is held across batches.
            losses.append(loss.item())
            tot_acc += acc
            tot += len(labels)

    loss = torch.tensor(losses).mean()
    if tot > 0:
        final_acc = float(tot_acc) / tot
    else:
        final_acc = 0
        logger.info("Error in calculating final_acc")
    return final_acc, loss


def print_user_flags(FLAGS, line_limit=80):
    """Log every flag as ``name.....value`` padded with dots to ``line_limit`` columns."""
    rule = "-" * line_limit
    rows = []
    for flag_name in sorted(vars(FLAGS)):
        value = "{}".format(getattr(FLAGS, flag_name))
        dots = "." * (line_limit - len(flag_name) - len(value))
        rows.append(flag_name + dots + value + "\n")
    logger.info("\n" + rule + "\n" + "".join(rows) + rule)


def count_model_params(trainable_params):
    """Return the total number of elements over the given parameters."""
    return sum(var.numel() for var in trainable_params)


def update_lr(
        optimizer,
        epoch,
        l2_reg=1e-4,
        lr_warmup_val=None,
        lr_init=0.1,
        lr_decay_scheme="cosine",
        lr_max=0.002,
        lr_min=0.000000001,
        lr_T_0=4,
        lr_T_mul=1,
        sync_replicas=False,
        num_aggregate=None,
        num_replicas=None):
    """Set a cosine-annealed learning rate with warm restarts on ``optimizer``.

    The first cycle is ``lr_T_0`` epochs long and each following cycle is
    ``lr_T_mul`` times longer. Within a cycle the rate goes from ``lr_max``
    towards ``lr_min`` along half a cosine period. ``l2_reg``,
    ``lr_warmup_val``, ``lr_init``, ``sync_replicas``, ``num_aggregate`` and
    ``num_replicas`` are accepted but unused.

    Returns
    -------
    float
        The learning rate written into every parameter group.

    Raises
    ------
    ValueError
        If ``lr_decay_scheme`` is not ``"cosine"``, or the cycle length is not
        positive (the restart search could never terminate).
    """
    if lr_decay_scheme != "cosine":
        raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme))

    assert lr_max is not None, "Need lr_max to use lr_cosine"
    assert lr_min is not None, "Need lr_min to use lr_cosine"
    assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
    assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"

    # Find the cycle that contains `epoch` and where that cycle started.
    T_i = lr_T_0
    t_epoch = epoch
    last_reset = 0
    while True:
        if T_i <= 0:
            raise ValueError("Cosine cycle length must be positive, got {}".format(T_i))
        t_epoch -= T_i
        if t_epoch < 0:
            break
        last_reset += T_i
        T_i *= lr_T_mul

    T_curr = epoch - last_reset
    rate = T_curr / T_i * math.pi
    learning_rate = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(rate))

    # update lr in optimizer
    for params_group in optimizer.param_groups:
        params_group['lr'] = learning_rate
    return learning_rate
+def train(device, data_path, output_dir, num_layers): + logger.info("Build dataloader") + train_dataset, valid_dataset, test_dataset, embedding = \ + read_data_sst(data_path, + FLAGS.max_input_length, + FLAGS.min_count, + train_ratio=FLAGS.train_ratio, + valid_ratio=FLAGS.valid_ratio, + is_binary=FLAGS.is_binary) + train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, shuffle=True, pin_memory=True) + test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True) + valid_dataloader = DataLoader(valid_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True) + + logger.info("Build model") + child_model = get_model(embedding, num_layers) + logger.info("Finish build model") + + #for name, var in child_model.named_parameters(): + # logger.info(name, var.size(), var.requires_grad) # output all params + + num_vars = count_model_params(child_model.parameters()) + logger.info("Model has {} params".format(num_vars)) + + for m in child_model.modules(): # initializer + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.xavier_uniform_(m.weight) + + criterion = nn.CrossEntropyLoss() + + # get optimizer + if FLAGS.child_optim_algo == "adam": + optimizer = optim.Adam(child_model.parameters(), eps=1e-3, weight_decay=FLAGS.child_l2_reg) # with L2 + else: + raise ValueError("Unknown optim_algo {}".format(FLAGS.child_optim_algo)) + + child_model.to(device) + criterion.to(device) + + logger.info("Start training") + start_time = time.time() + step = 0 + + # save path + model_save_path = os.path.join(FLAGS.output_dir, "model.pth") + best_model_save_path = os.path.join(FLAGS.output_dir, "best_model.pth") + best_acc = 0 + start_epoch = 0 + if FLAGS.load_checkpoint: + if os.path.isfile(model_save_path): + checkpoint = torch.load(model_save_path, map_location = torch.device('cpu')) + step = checkpoint['step'] + start_epoch = checkpoint['epoch'] + child_model.load_state_dict(checkpoint['child_model_state_dict']) + 
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + for epoch in range(start_epoch, FLAGS.num_epochs): + lr = update_lr(optimizer, + epoch, + l2_reg=FLAGS.child_l2_reg, + lr_warmup_val=None, + lr_init=FLAGS.child_lr, + lr_decay_scheme=FLAGS.child_lr_decay_scheme, + lr_max=FLAGS.child_lr_max, + lr_min=FLAGS.child_lr_min, + lr_T_0=FLAGS.child_lr_T_0, + lr_T_mul=FLAGS.child_lr_T_mul) + child_model.train() + for batch in train_dataloader: + (sent_ids, mask), labels = batch + + sent_ids = sent_ids.to(device, non_blocking=True) + mask = mask.to(device, non_blocking=True) + labels = labels.to(device, non_blocking=True) + + step += 1 + + logits = child_model((sent_ids, mask)) # run + + loss = criterion(logits, labels.long()) + loss = loss.mean() + preds = logits.argmax(dim=1).long() + acc = torch.eq(preds, labels.long()).long().sum().item() + + optimizer.zero_grad() + loss.backward() + grad_norm = 0 + trainable_params = child_model.parameters() + + assert FLAGS.child_grad_bound is not None, "Need grad_bound to clip gradients." 
+ # compute the gradient norm value + grad_norm = nn.utils.clip_grad_norm_(trainable_params, 99999999) + for param in trainable_params: + nn.utils.clip_grad_norm_(param, FLAGS.child_grad_bound) # clip grad + + optimizer.step() + + if step % FLAGS.log_every == 0: + curr_time = time.time() + log_string = "" + log_string += "epoch={:<6d}".format(epoch) + log_string += "ch_step={:<6d}".format(step) + log_string += " loss={:<8.6f}".format(loss) + log_string += " lr={:<8.4f}".format(lr) + log_string += " |g|={:<8.4f}".format(grad_norm) + log_string += " tr_acc={:<3d}/{:>3d}".format(acc, logits.size()[0]) + log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60) + logger.info(log_string) + + epoch += 1 + save_state = { + 'step' : step, + 'epoch' : epoch, + 'child_model_state_dict' : child_model.state_dict(), + 'optimizer_state_dict' : optimizer.state_dict()} + torch.save(save_state, model_save_path) + child_model.eval() + logger.info("Epoch {}: Eval".format(epoch)) + eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader) + logger.info("ch_step={} {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format(step, "test", eval_acc, "test", eval_loss)) + if eval_acc > best_acc: + best_acc = eval_acc + logger.info("Save best model") + save_state = { + 'step' : step, + 'epoch' : epoch, + 'child_model_state_dict' : child_model.state_dict(), + 'optimizer_state_dict' : optimizer.state_dict()} + torch.save(save_state, best_model_save_path) + + return eval_acc + + +def main(): + parse_args() + if not os.path.isdir(FLAGS.output_dir): + logger.info("Path {} does not exist. Creating.".format(FLAGS.output_dir)) + os.makedirs(FLAGS.output_dir) + elif FLAGS.reset_output_dir: + logger.info("Path {} exists. 
Remove and remake.".format(FLAGS.output_dir)) + shutil.rmtree(FLAGS.output_dir, ignore_errors=True) + os.makedirs(FLAGS.output_dir) + + print_user_flags(FLAGS) + + if FLAGS.fixed_seed: + set_random_seed(FLAGS.global_seed) + + device = torch.device("cuda" if FLAGS.is_cuda else "cpu") + train(device, FLAGS.data_path, FLAGS.output_dir, FLAGS.child_num_layers) + + +if __name__ == "__main__": + main() diff --git a/dubhe-tadl/textnas/search.py b/dubhe-tadl/textnas/search.py new file mode 100644 index 0000000..f0ed6f3 --- /dev/null +++ b/dubhe-tadl/textnas/search.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import os +import random +from argparse import ArgumentParser +from itertools import cycle + +import numpy as np +import torch +import torch.nn as nn +import sys +sys.path.append("../..") +from pytorch.enas import EnasMutator, EnasTrainer +from pytorch.callbacks import LRSchedulerCallback +from pytorch.mutables import LayerChoice, InputChoice, MutableScope + +from dataloader import read_data_sst +from model import Model +from utils import accuracy, dump_global_result + +from collections import OrderedDict +import os +import json +import time + +logger = logging.getLogger("nni.textnas") +logger.setLevel(logging.INFO) + +# For debugging mode +# os.chdir('/home/yangyi/pytorch/textnas') +os.environ["CUDA_VISIBLE_DEVICES"]='4' + + +def save_textnas_search_space(mutator,file_path): + result = OrderedDict() + cur_layer_idx = None + for mutable in mutator.mutables.traverse(): + if not isinstance(mutable,(LayerChoice, InputChoice)): + cur_layer_idx = mutable.key + continue + if isinstance(mutable,LayerChoice): + if 'op_list' not in result: + result['op_list'] = [str(i) for i in mutable] + result[cur_layer_idx+ '_'+ mutable.key] = 'op_list' + + else: + result[cur_layer_idx+ '_'+ mutable.key] = {'skip_connection':False if mutable.n_chosen else True, + 'n_chosen': mutable.n_chosen if mutable.n_chosen else '', + 
'choose_from': mutable.choose_from if mutable.choose_from else ''} + + + dump_global_result(file_path,result) + + +class TextNASTrainer(EnasTrainer): + def __init__(self, *args, train_loader=None, valid_loader=None, test_loader=None, **kwargs): + super().__init__(*args, **kwargs) + self.train_loader = train_loader + self.valid_loader = valid_loader + self.test_loader = test_loader + self.result = {'accuracy':[], + 'cost_time':0} + def init_dataloader(self): + pass + + + +if __name__ == "__main__": + parser = ArgumentParser("textnas") + parser.add_argument("--search_space_path", type=str, + default='./search_space.json', help="search_space directory") + parser.add_argument("--selected_space_path", type=str, + default='./selected_space.json', help="sapce_path_out directory") + parser.add_argument("--result_path", type=str, + default='./result.json', help="res directory") + parser.add_argument('--trial_id', type=int, default=0, metavar='N', + help='trial_id,start from 0') + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--log-frequency", default=50, type=int) + parser.add_argument("--epochs", default=2, type=int) + parser.add_argument("--lr", default=5e-3, type=float) + args = parser.parse_args() + # 设置随机种子 + torch.manual_seed(args.trial_id) + torch.cuda.manual_seed_all(args.trial_id) + np.random.seed(args.trial_id) + random.seed(args.trial_id) + # use deterministic instead of nondeterministic algorithm + # make sure exact results can be reproduced everytime. 
+ torch.backends.cudnn.deterministic = True + + + + + # 配置计算资源及load数据 + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data") + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True) + valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4) + train_loader, valid_loader = cycle(train_loader), cycle(valid_loader) + + + # 导入模型以及预训练的词向量 + model = Model(embedding) + + + # 实例化一个mutator, mutator主要是用于选择搜索空间的 + mutator = EnasMutator(model, temperature=None, tanh_constant=None, entropy_reduction="mean") + + # 储存整个网络结构 + save_textnas_search_space(mutator, args.search_space_path) + + criterion = nn.CrossEntropyLoss() + # 实例化优化器 + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-3, weight_decay=2e-6) + # 实例化学习率变化器 + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5) + # 实例话一个训练器 + trainer = TextNASTrainer(model, + loss=criterion, + metrics=lambda output, target: {"acc": accuracy(output, target)}, + reward_function=accuracy, + optimizer=optimizer, + callbacks=[LRSchedulerCallback(lr_scheduler)], + batch_size=args.batch_size, + num_epochs=args.epochs, + dataset_train=None, + dataset_valid=None, + train_loader=train_loader, + valid_loader=valid_loader, + test_loader=test_loader, + log_frequency=args.log_frequency, + mutator=mutator, + mutator_lr=2e-3, + mutator_steps=5, + mutator_steps_aggregate=1, + child_steps=50, + baseline_decay=0.99, + test_arc_per_epoch=10) + + + logger.info(trainer.metrics) + + t1 = time.time() + trainer.train() + trainer.result["cost_time"] = time.time() - t1 + dump_global_result(args.result_path,trainer.result) + + # os.makedirs("checkpoints", exist_ok=True) + # 
INF = 1E10   # penalty subtracted from masked-out positions before max-pooling
EPS = 1E-12  # substitute length for all-zero masks so the average never divides by zero

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_length(mask):
    """
    Sum ``mask`` along dimension 1 and return the sums as a CPU ``long`` tensor.

    NOTE(review): presumably ``mask`` is a (batch, seq_len) 0/1 padding mask, which
    would make the result the per-sequence token count -- confirm against callers.
    """
    return torch.sum(mask, 1).long().cpu()


class GlobalAvgPool(nn.Module):
    """Average ``x`` over dimension 2, dividing by the number of non-zero mask entries."""

    def forward(self, x, mask):
        """
        Parameters
        ----------
        x : torch.Tensor
            Tensor reduced over dimension 2. All positions are summed, the mask is
            NOT applied to ``x`` here (assumes the caller already zeroed padded
            positions -- TODO confirm).
        mask : torch.Tensor
            Tensor reduced over dimension 1 to obtain the divisor of each row.

        Returns
        -------
        torch.Tensor
            ``x`` summed over dimension 2 and divided by the per-row mask sum.
        """
        summed = torch.sum(x, 2)
        length = torch.sum(mask, 1, keepdim=True).float()
        # Rows whose mask is entirely zero would divide by zero; use EPS instead.
        length += torch.eq(length, 0.0).float() * EPS
        # ``length`` keeps a trailing dimension of size 1 and broadcasts over dimension 1.
        summed /= length
        return summed


class GlobalMaxPool(nn.Module):
    """Max over dimension 2 of ``x`` that ignores positions where ``mask`` is zero."""

    def forward(self, x, mask):
        """
        Parameters
        ----------
        x : torch.Tensor
            Tensor pooled over dimension 2. It is left unmodified.
        mask : torch.Tensor
            Non-zero entries mark positions that may win the max. It is broadcast
            over dimension 1 of ``x``.

        Returns
        -------
        torch.Tensor
            Maximum of ``x`` over dimension 2, where masked-out positions have been
            lowered by ``INF``. A row without any valid position therefore yields a
            value around ``-INF`` instead of a real activation.
        """
        # The previous implementation multiplied an integer mask by -INF in place
        # (rejected by PyTorch type promotion), wrote the penalty into the caller's
        # tensor and then added it a second time. Build the penalised view
        # out-of-place and apply it exactly once.
        padding = torch.eq(mask.float(), 0.0).unsqueeze(1)
        penalised = torch.where(padding, x - INF, x)
        pooled, _ = torch.max(penalised, 2)
        return pooled


class IteratorWrapper:
    """
    Adapt a loader of batches exposing ``.text`` and ``.label`` into an iterator of
    ``((text, mask), label)`` tuples.
    """

    def __init__(self, loader):
        self.loader = loader
        self.iterator = None

    def __iter__(self):
        # Restart from the beginning of the wrapped loader on every ``iter()`` call.
        self.iterator = iter(self.loader)
        return self

    def __len__(self):
        return len(self.loader)

    def __next__(self):
        if self.iterator is None:
            # ``next()`` used without a preceding ``iter()``: start lazily instead of
            # failing on ``next(None)``.
            self.iterator = iter(self.loader)
        data = next(self.iterator)
        text, length = data.text
        max_length = text.size(1)
        # Labels are shifted down by one (presumably 1-based in the dataset -- verify).
        label = data.label - 1
        # mask[i, j] is True while j < length[i]; broadcasting builds the
        # (batch, max_length) grid.
        positions = torch.arange(max_length, device=length.device).unsqueeze(0)
        mask = positions < length.unsqueeze(-1)
        return (text, mask), label


def accuracy(output, target):
    """
    Return the fraction of rows of ``output`` whose arg-max along dimension 1
    equals ``target``.

    Raises ``ZeroDivisionError`` when ``target`` is empty along dimension 0.
    """
    batch_size = target.size(0)
    _, predicted = torch.max(output.detach(), 1)
    return (predicted == target).sum().item() / batch_size
def dump_global_result(res_path, global_result, sort_keys=False):
    """
    Write ``global_result`` to ``res_path`` as JSON indented by two spaces.

    Parameters
    ----------
    res_path : str
        Destination file. It is overwritten.
    global_result : object
        JSON-serialisable payload.
    sort_keys : bool
        Forwarded to ``json.dump``.
    """
    with open(res_path, "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2)


_logger = logging.getLogger(__name__)


class TorchTensorEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``torch.Tensor`` objects through ``Tensor.tolist()``."""

    def default(self, o):  # pylint: disable=method-hidden
        if isinstance(o, torch.Tensor):
            olist = o.tolist()
            # ``tolist()`` returns a bare number for 0-dim tensors and nested lists
            # for tensors with more than one dimension, so the 0/1 check runs on a
            # flattened copy instead of iterating ``olist`` directly (which raised
            # TypeError on scalars).
            flat = o.reshape(-1).tolist()
            if flat and "bool" not in o.type().lower() and all(d in (0, 1) for d in flat):
                _logger.warning("Every element in %s is either 0 or 1. "
                                "You might consider convert it into bool.", olist)
            return olist
        return super().default(o)


class Trainer(BaseTrainer):
    """
    A trainer with some helper functions implemented. To implement a new trainer,
    users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    mutator : BaseMutator
        A mutator object that has been initialized with the model.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
        See `PyTorch loss functions`_ for examples.
    metrics : callable
        Called with logits and targets. Returns a dict that maps metrics keys to metrics data. For example,

        .. code-block:: python

            def metrics_fn(output, target):
                return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}

    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    dataset_train : torch.utils.data.Dataset
        Dataset of training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
        PyTorch Dataset. See `torch.utils.data`_ for examples.
    dataset_valid : torch.utils.data.Dataset
        Dataset of validation/testing.
    batch_size : int
        Batch size.
    workers : int
        Number of workers used in data preprocessing.
    device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
        automatically detects a GPU and prefers it over the CPU.
    log_frequency : int
        Number of mini-batches to log metrics.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks. ``None`` is treated as an empty list.


    .. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
    .. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
    """

    def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = model
        self.mutator = mutator
        self.loss = loss

        self.metrics = metrics
        self.optimizer = optimizer

        self.model.to(self.device)
        self.mutator.to(self.device)
        # ``loss`` is documented as any callable; only objects that expose ``to``
        # (e.g. ``nn.Module`` losses) can be moved, plain functions are left alone.
        move_loss = getattr(self.loss, "to", None)
        if callable(move_loss):
            move_loss(self.device)
        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.batch_size = batch_size
        self.workers = workers
        self.log_frequency = log_frequency
        # self.log_dir = os.path.join("logs", str(time.time()))
        # os.makedirs(self.log_dir, exist_ok=True)
        # self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
        self.callbacks = callbacks if callbacks is not None else []
        for callback in self.callbacks:
            callback.build(self.model, self.optimizer, self.mutator, self)

    @abstractmethod
    def train_one_epoch(self, epoch):
        """
        Train one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """

    @abstractmethod
    def validate_one_epoch(self, epoch):
        """
        Validate one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """

    def train(self, validate=True):
        """
        Train ``num_epochs``.
        Trigger callbacks at the start and the end of each epoch.

        Parameters
        ----------
        validate : bool
            If ``true``, will do validation every epoch.
        """
        for epoch in range(self.num_epochs):
            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            # training (log messages are 1-based, the epoch argument is 0-based)
            _logger.info("Epoch %d Training", epoch + 1)
            self.train_one_epoch(epoch)
            if validate:
                # validation
                _logger.info("Epoch %d Validating", epoch + 1)
                self.validate_one_epoch(epoch)

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        """
        Do one validation. ``validate_one_epoch`` is called with epoch ``-1``.
        """
        self.validate_one_epoch(-1)

    def export(self, file):
        """
        Call ``mutator.export()`` and dump the architecture to ``file``.

        Parameters
        ----------
        file : str
            A file path. Expected to be a JSON.
        """
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        """
        Return trainer checkpoint.

        Raises
        ------
        NotImplementedError
            Always; subclasses are expected to override this method.
        """
        raise NotImplementedError("Not implemented yet")

    # The graph output format has not been decided yet.
    # def enable_visualization(self):
    #     """
    #     Enable visualization. Write graph and training log to folder ``logs/``.
    #     """
    #     sample = None
    #     for x, _ in self.train_loader:
    #         sample = x.to(self.device)[:2]
    #         break
    #     if sample is None:
    #         _logger.warning("Sample is %s.", sample)
    #     _logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
    #     with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
    #         json.dump(self.mutator.graph(sample), f)
    #     self.visualization_enabled = True

    # def _write_graph_status(self):
    #     if hasattr(self, "visualization_enabled") and self.visualization_enabled:
    #         print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)
_counter = 0


def global_mutable_counting():
    """
    A program level counter starting from 1.
    """
    global _counter
    _counter += 1
    return _counter


def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch (all CUDA devices when CUDA is
    available), then switch cuDNN to deterministic, non-benchmark mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def _reset_global_mutable_counting():
    """
    Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
    """
    global _counter
    _counter = 0


def to_device(obj, device):
    """
    Move a tensor, tuple, list, or dict onto device.

    Containers are rebuilt recursively (tuples as plain tuples, lists as plain
    lists, dicts as plain dicts); ``int``, ``float`` and ``str`` values are
    returned unchanged.

    Raises
    ------
    ValueError
        If ``obj`` (or a nested element) has any other type, including ``None``.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
        return tuple(to_device(item, device) for item in obj)
    if isinstance(obj, list):
        return [to_device(item, device) for item in obj]
    if isinstance(obj, dict):
        return {key: to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, (int, float, str)):
        return obj
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def to_list(arr):
    """
    Convert a tensor, NumPy array, list or tuple into a Python list; any other
    value is returned as is.
    """
    if torch.is_tensor(arr):
        return arr.cpu().numpy().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr


def count_parameters_in_MB(model):
    """
    Return the number of parameters of ``model`` divided by 1e6, skipping every
    parameter whose name contains ``"auxiliary"``.
    """
    # ``np.sum`` over a generator is deprecated in NumPy; the built-in ``sum`` over
    # ``numel()`` gives the same count without going through NumPy at all.
    total = sum(param.numel() for name, param in model.named_parameters()
                if "auxiliary" not in name)
    return total / 1e6


def str2bool(str):
    """Return ``True`` only when the argument equals ``'true'`` ignoring case."""
    return str.lower() == 'true'


class AverageMeterGroup:
    """
    Average meter group for multiple average meters.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Non-exist average meters will be automatically created.
        """
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. Read ``meters`` from the
        # instance dict so a half-initialised object (copy/pickle) cannot recurse,
        # and raise AttributeError (not KeyError) so ``hasattr`` and
        # ``getattr(obj, name, default)`` behave as Python expects.
        meters = self.__dict__.get("meters")
        if meters is None or item not in meters:
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(type(self).__name__, item))
        return meters[item]

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return "  ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.
        """
        return "  ".join(v.summary() for v in self.meters.values())

    def get_last_acc(self):
        """
        Parse the average out of the summary string of the first meter.

        The value goes through the meter's format string, so it carries the
        printed precision. Raises ``IndexError`` when the group is empty.
        """
        first_summary = [v.summary() for v in self.meters.values()][0]
        return float(first_summary.split(': ')[1])


class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------
    name : str
        Name to display.
    fmt : str
        Format string to print the values.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted in.
        n : int
            The weight of the new value.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Renders as "<name> <val> (<avg>)" using ``fmt`` for both numbers.
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return ``"<name>: <avg>"`` with the average rendered through ``fmt``."""
        fmtstr = '{name}: {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
class StructuredMutableTreeNode:
    """
    A structured representation of a search space.
    A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
    ``Mutable`` (other than ``MutableScope``).

    Parameters
    ----------
    mutable : nni.nas.pytorch.mutables.Mutable
        The mutable that current node is linked with.
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """
        Wrap ``mutable`` in a new node, append it to the children of the current
        node and return the new node.
        """
        node = StructuredMutableTreeNode(mutable)
        self.children.append(node)
        return node

    def type(self):
        """
        Return the ``type`` of mutable content.
        """
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def _visit_self(self, deduplicate, memo):
        """Yield this node's mutable unless it is ``None`` or deduplicated away."""
        if self.mutable is None:
            return
        if not deduplicate or self.mutable.key not in memo:
            memo.add(self.mutable.key)
            yield self.mutable

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Return a generator that generates a list of mutables in this tree.

        Parameters
        ----------
        order : str
            pre or post. If pre, current mutable is yield before children. Otherwise after.
        deduplicate : bool
            If true, mutables with the same key will not appear after the first appearance.
        memo : set
            An auxiliary set that memorizes keys seen before, so that deduplication is possible.
            A fresh set is created when omitted.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            yield from self._visit_self(deduplicate, memo)
        for child in self.children:
            yield from child.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if order == "post":
            yield from self._visit_self(deduplicate, memo)


def dump_global_result(res_path, global_result):
    """
    Write ``global_result`` to ``res_path`` as JSON indented by two spaces, using
    ``TorchTensorEncoder`` for values the default encoder cannot serialise.
    """
    with open(res_path, "w") as ss_file:
        json.dump(global_result, ss_file, indent=2, cls=TorchTensorEncoder)


def save_best_checkpoint(checkpoint_dir, model, optimizer, epoch):
    """
    Dump model and optimizer state to
    ``<checkpoint_dir>/best_checkpoint_epoch{epoch}.pth``.
    ``DataParallel`` object will have their inside modules exported.

    Parameters
    ----------
    checkpoint_dir : str
        Destination directory; created when it does not exist yet.
    model : nn.Module
        Model whose ``state_dict`` is stored under ``'child_model_state_dict'``.
    optimizer : Optimizer
        Optimizer whose ``state_dict`` is stored under ``'optimizer_state_dict'``.
    epoch : int
        Stored under ``'epoch'`` and used in the file name.
    """
    # Unwrap DataParallel so the saved keys carry no "module." prefix.
    source = model.module if isinstance(model, nn.DataParallel) else model

    save_state = {'child_model_state_dict': source.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'epoch': epoch}

    if checkpoint_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
    dest_path = os.path.join(checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))
    torch.save(save_state, dest_path)
# Accepted ``log_level_name`` values of ``init_logger``.
log_level_map = {
    'fatal': logging.FATAL,
    'error': logging.ERROR,
    'warning': logging.WARNING,
    'info': logging.INFO,
    'debug': logging.DEBUG
}

_time_format = '%m/%d/%Y, %I:%M:%S %p'


class _LoggerFileWrapper(TextIOBase):
    """File-like object that prefixes every written chunk with a timestamp and ``PRINT``."""

    def __init__(self, logger_file):
        self.file = logger_file

    def write(self, s):
        # A bare newline (what ``print`` emits as its ``end``) is dropped; every
        # other chunk becomes one timestamped line that is flushed immediately.
        if s != '\n':
            cur_time = datetime.now().strftime(_time_format)
            self.file.write('[{}] PRINT '.format(cur_time) + s + '\n')
            self.file.flush()
        return len(s)


def init_logger(logger_file_path, log_level_name='info'):
    """Initialize root logger.
    This will redirect anything from logging.getLogger() as well as stdout to specified file.
    logger_file_path: path of logger file (path-like object).
    log_level_name: one of the keys of ``log_level_map``.

    The log file is truncated once and then only appended to. ``sys.stdout`` is
    replaced for the rest of the process, and every call adds another pair of
    handlers to the root logger.

    Raises ``ValueError`` for an unknown ``log_level_name`` before touching the
    file or the logging configuration.
    """
    log_level = log_level_map.get(log_level_name)
    if log_level is None:
        # Previously ``None`` reached ``setLevel`` after the file and handlers had
        # already been set up.
        raise ValueError("Unknown log level '{}', expected one of {}".format(
            log_level_name, sorted(log_level_map)))

    fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
    logging.Formatter.converter = time.localtime
    formatter = logging.Formatter(fmt, _time_format)

    # Start from an empty file, then let BOTH writers append. The old code kept a
    # 'w' handle for stdout next to the appending FileHandler, so printed output
    # was written at the 'w' handle's own offset and overwrote earlier log records.
    with open(logger_file_path, 'w'):
        pass

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    file_handler = logging.FileHandler(logger_file_path)
    file_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.addHandler(stream_handler)
    root_logger.addHandler(file_handler)
    root_logger.setLevel(log_level)

    # include print function output; the handle is intentionally kept open because
    # it backs sys.stdout from here on
    logger_file = open(logger_file_path, 'a')
    sys.stdout = _LoggerFileWrapper(logger_file)


def mkdirs(*args):
    """
    Create the parent directory of every given path when it does not exist yet.

    Paths without a directory component are skipped. A message is printed for
    each directory that is created.
    """
    for path in args:
        dirname = os.path.dirname(path)
        if dirname and not os.path.exists(dirname):
            print("make {} in dir: {}".format(path, dirname))
            # exist_ok guards against another process creating it in between.
            os.makedirs(dirname, exist_ok=True)


def list_str2int(ls):
    """Return a list with ``int()`` applied to every element of ``ls``."""
    return [int(item) for item in ls]