@@ -0,0 +1,2 @@
# from .log import init_logger
# init_logger()
@@ -0,0 +1,159 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch.nn as nn

from .mutables import Mutable, MutableScope, InputChoice
from .utils import StructuredMutableTreeNode

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class BaseMutator(nn.Module):
    """
    A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
    callbacks that are called in ``forward`` in mutables.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to apply mutator on.
    """

    def __init__(self, model):
        super().__init__()
        self.__dict__["model"] = model
        self._structured_mutables = self._parse_search_space(self.model)

    def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
        if memo is None:
            memo = set()
        if root is None:
            root = StructuredMutableTreeNode(None)
        if module not in memo:
            memo.add(module)
            if isinstance(module, Mutable):
                if nested_detection is not None:
                    raise RuntimeError("Cannot have nested search space. Error at {} in {}"
                                       .format(module, nested_detection))
                module.name = prefix
                module.set_mutator(self)
                root = root.add_child(module)
                if not isinstance(module, MutableScope):
                    nested_detection = module
                if isinstance(module, InputChoice):
                    for k in module.choose_from:
                        if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
                            raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
                                               .format(k, module.key))
            for name, submodule in module._modules.items():
                if submodule is None:
                    continue
                submodule_prefix = prefix + ("." if prefix else "") + name
                self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
                                         nested_detection=nested_detection)
        return root
    @property
    def mutables(self):
        """
        A generator of all modules inheriting :class:`~nni.nas.pytorch.mutables.Mutable`.
        Modules are yielded in the order that they are defined in ``__init__``.
        For mutables whose keys appear multiple times, only the first occurrence is yielded.
        """
        return self._structured_mutables

    @property
    def undedup_mutables(self):
        return self._structured_mutables.traverse(deduplicate=False)

    def forward(self, *inputs):
        """
        Warnings
        --------
        Don't call forward of a mutator.
        """
        raise RuntimeError("Forward is undefined for mutators.")
    def __setattr__(self, name, value):
        if name == "model":
            raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
                                 "include your network, as it will add all parameters of the model into the mutator.")
        return super().__setattr__(name, value)
    def enter_mutable_scope(self, mutable_scope):
        """
        Callback when forward of a MutableScope is entered.

        Parameters
        ----------
        mutable_scope : MutableScope
            The mutable scope that is entered.
        """
        pass

    def exit_mutable_scope(self, mutable_scope):
        """
        Callback when forward of a MutableScope is exited.

        Parameters
        ----------
        mutable_scope : MutableScope
            The mutable scope that is exited.
        """
        pass

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callback of forward in LayerChoice.

        Parameters
        ----------
        mutable : LayerChoice
            Module whose forward is called.
        args : list of torch.Tensor
            The arguments of its forward function.
        kwargs : dict
            The keyword arguments of its forward function.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output tensor and mask.
        """
        raise NotImplementedError

    def on_forward_input_choice(self, mutable, tensor_list):
        """
        Callback of forward in InputChoice.

        Parameters
        ----------
        mutable : InputChoice
            Mutable that is called.
        tensor_list : list of torch.Tensor
            The arguments mutable is called with.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output tensor and mask.
        """
        raise NotImplementedError

    def export(self):
        """
        Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
        network can be fully determined with these decisions for further training from scratch.

        Returns
        -------
        dict
            Mappings from mutable keys to decisions.
        """
        raise NotImplementedError
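
# --- Illustrative addition (not part of the original diff): a minimal sketch of
# a concrete mutator built on BaseMutator. It always activates the first
# candidate of every LayerChoice/InputChoice; the signatures follow the
# docstrings above, and `FirstChoiceMutator` is a hypothetical name.
import torch


class FirstChoiceMutator(BaseMutator):
    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        # Run only candidate 0; the mask marks which candidate was active.
        out = mutable[0](*args, **kwargs)
        mask = torch.tensor([i == 0 for i in range(len(mutable))], dtype=torch.bool)
        return out, mask

    def on_forward_input_choice(self, mutable, tensor_list):
        # Select only the first input tensor.
        mask = torch.tensor([i == 0 for i in range(mutable.n_candidates)], dtype=torch.bool)
        return tensor_list[0], mask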
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import ABC, abstractmethod


class BaseTrainer(ABC):

    @abstractmethod
    def train(self):
        """
        Override the method to train.
        """
        raise NotImplementedError

    @abstractmethod
    def validate(self):
        """
        Override the method to validate.
        """
        raise NotImplementedError

    @abstractmethod
    def export(self, file):
        """
        Override the method to export to file.

        Parameters
        ----------
        file : str
            File path to export to.
        """
        raise NotImplementedError

    @abstractmethod
    def checkpoint(self):
        """
        Override to dump a checkpoint.
        """
        raise NotImplementedError
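
# --- Illustrative addition (not part of the original diff): the BaseTrainer
# contract in action. A subclass can only be instantiated once all four
# abstract methods are implemented; `NoOpTrainer` is a hypothetical example.
class NoOpTrainer(BaseTrainer):
    def train(self):
        pass

    def validate(self):
        pass

    def export(self, file):
        # Write an empty decision dict as a placeholder.
        with open(file, "w") as f:
            f.write("{}")

    def checkpoint(self):
        pass

# NoOpTrainer() can be constructed; omitting any method above raises TypeError.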
@@ -0,0 +1,167 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class Callback:
    """
    Callback provides an easy way to react to events like begin/end of epochs.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.mutator = None
        self.trainer = None
    def build(self, model, optimizer, mutator, trainer):
        """
        Callback needs to be built with model, optimizer, mutator and trainer, to get updates from them.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        optimizer : Optimizer
            Optimizer used to train the model.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that is to call the callback.
        """
        self.model = model
        self.optimizer = optimizer
        self.mutator = mutator
        self.trainer = trainer
    def on_epoch_begin(self, epoch):
        """
        Implement this to do something at the beginning of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_epoch_end(self, epoch):
        """
        Implement this to do something at the end of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_batch_begin(self, epoch):
        pass

    def on_batch_end(self, epoch):
        pass
class LRSchedulerCallback(Callback):
    """
    Calls the scheduler at the end of every epoch.

    Parameters
    ----------
    scheduler : LRScheduler
        Scheduler to be called.
    """

    def __init__(self, scheduler, mode="epoch"):
        super().__init__()
        assert mode == "epoch"
        self.scheduler = scheduler
        self.mode = mode

    def on_epoch_end(self, epoch):
        """
        Call ``self.scheduler.step()`` on epoch end.
        """
        self.scheduler.step()
class ArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``<checkpoint_dir>/epoch_{number}.json`` on epoch end.
        """
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
        _logger.info("Saving architecture to %s", dest_path)
        self.trainer.export(dest_path)
class BestArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of the final epoch.

    Parameters
    ----------
    checkpoint_path : str
        Location to save the checkpoint.
    epochs : int
        Total number of epochs; the export happens when the last one ends.
    """

    def __init__(self, checkpoint_path, epochs):
        super().__init__()
        self.epochs = epochs
        self.checkpoint_path = checkpoint_path

    def on_epoch_end(self, epoch):
        """
        Dump to ``checkpoint_path`` at the end of the final epoch.
        """
        if epoch == self.epochs - 1:
            _logger.info("Saving architecture to %s", self.checkpoint_path)
            self.trainer.export(self.checkpoint_path)
class ModelCheckpoint(Callback):
    """
    Saves the model and optimizer state at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``<checkpoint_dir>/epoch_{number}.pth.tar`` on every epoch end.
        ``DataParallel`` objects will have their inner modules exported.
        """
        if isinstance(self.model, nn.DataParallel):
            child_model_state_dict = self.model.module.state_dict()
        else:
            child_model_state_dict = self.model.state_dict()
        save_state = {'child_model_state_dict': child_model_state_dict,
                      'optimizer_state_dict': self.optimizer.state_dict(),
                      'epoch': epoch}
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
        _logger.info("Saving model to %s", dest_path)
        torch.save(save_state, dest_path)
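
# --- Illustrative addition (not part of the original diff): how a trainer
# might wire these callbacks together. `model`, `optimizer`, `mutator`,
# `trainer` and `num_epochs` are placeholders assumed to exist.
#
# callbacks = [
#     LRSchedulerCallback(torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)),
#     ArchitectureCheckpoint("./checkpoints"),
#     ModelCheckpoint("./checkpoints"),
# ]
# for cb in callbacks:
#     cb.build(model, optimizer, mutator, trainer)
# for epoch in range(num_epochs):
#     for cb in callbacks:
#         cb.on_epoch_begin(epoch)
#     # ... train one epoch ...
#     for cb in callbacks:
#         cb.on_epoch_end(epoch)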
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import sys

sys.path.append('../..')

from pytorch.mutables import InputChoice, LayerChoice, MutableScope
from pytorch.mutator import Mutator
from pytorch.utils import to_list

_logger = logging.getLogger(__name__)
# _logger.setLevel(logging.INFO)


class FixedArchitecture(Mutator):
    """
    Fixed architecture mutator that always selects a certain graph.

    Parameters
    ----------
    model : nn.Module
        A mutable network.
    fixed_arc : dict
        Preloaded architecture object.
    strict : bool
        Force everything that appears in ``fixed_arc`` to be used at least once.
    """

    def __init__(self, model, fixed_arc, strict=True):
        super().__init__(model)
        self._fixed_arc = fixed_arc

        mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
        fixed_arc_keys = set(self._fixed_arc.keys())
        if fixed_arc_keys - mutable_keys:
            raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
        if mutable_keys - fixed_arc_keys:
            raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
        self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)
    def _from_human_readable_architecture(self, human_arc):
        # convert from an exported architecture
        result_arc = {k: to_list(v) for k, v in human_arc.items()}  # there could be tensors, numpy arrays, etc.
        # First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
        # which means {"op1": [0, ]} or {"op1": ["conv", ]}
        result_arc = {k: v['_value'] if isinstance(v['_value'], list) else [v['_value']] for k, v in result_arc.items()}
        # Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
        # This is non-trivial: if an array is [0, 1], we cannot know for sure whether it means [false, true] or [true, true].
        # Here, we assume a multi-hot array has to be a boolean array or a float array and matches the length.
        for mutable in self.mutables:
            if mutable.key not in result_arc:
                continue  # skip silently
            choice_arr = result_arc[mutable.key]
            if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
                if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
                        (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
                    # multi-hot, do nothing
                    continue
            if isinstance(mutable, LayerChoice):
                choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
                choice_arr = [i in choice_arr for i in range(len(mutable))]
            elif isinstance(mutable, InputChoice):
                choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
                choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
            result_arc[mutable.key] = choice_arr
        return result_arc
    def sample_search(self):
        """
        Always returns the fixed architecture.
        """
        return self._fixed_arc

    def sample_final(self):
        """
        Always returns the fixed architecture.
        """
        return self._fixed_arc
def replace_layer_choice(self, module=None, prefix=""): | |||||
""" | |||||
Replace layer choices with selected candidates. It's done with best effort. | |||||
In case of weighted choices or multiple choices. if some of the choices on weighted with zero, delete them. | |||||
If single choice, replace the module with a normal module. | |||||
Parameters | |||||
---------- | |||||
module : nn.Module | |||||
Module to be processed. | |||||
prefix : str | |||||
Module name under global namespace. | |||||
""" | |||||
if module is None: | |||||
module = self.model | |||||
for name, mutable in module.named_children(): | |||||
global_name = (prefix + "." if prefix else "") + name | |||||
if isinstance(mutable, LayerChoice): | |||||
chosen = self._fixed_arc[mutable.key] | |||||
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask: | |||||
# sum is one, max is one, there has to be an only one | |||||
# this is compatible with both integer arrays, boolean arrays and float arrays | |||||
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1)) | |||||
setattr(module, name, mutable[chosen.index(1)]) | |||||
else: | |||||
if mutable.return_mask: | |||||
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \ | |||||
"LayerChoice will not be replaced.") | |||||
# remove unused parameters | |||||
for ch, n in zip(chosen, mutable.names): | |||||
if ch == 0 and not isinstance(ch, float): | |||||
setattr(mutable, n, None) | |||||
else: | |||||
self.replace_layer_choice(mutable, global_name) | |||||

def apply_fixed_architecture(model, fixed_arc):
    """
    Load architecture from ``fixed_arc`` and apply it to the model.

    Parameters
    ----------
    model : torch.nn.Module
        Model with mutables.
    fixed_arc : str or dict
        Path to the JSON file that stores the architecture, or dict that stores the exported architecture.

    Returns
    -------
    FixedArchitecture
        Mutator that is responsible for fixing the graph.
    """
    if isinstance(fixed_arc, str):
        with open(fixed_arc) as f:
            fixed_arc = json.load(f)
    architecture = FixedArchitecture(model, fixed_arc)
    architecture.reset()
    # for the convenience of parameter counting
    architecture.replace_layer_choice()
    return architecture
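
# --- Illustrative addition (not part of the original diff): a hedged usage
# sketch of apply_fixed_architecture with an exported dict, assuming a model
# whose mutable keys are "first_conv", "mid_conv" and "skip" (the format
# mirrors the selected_space.json example later in this diff).
#
# fixed = {
#     "first_conv": {"_value": "conv5x5", "_idx": 0},
#     "mid_conv": {"_value": "0", "_idx": 0},
#     "skip": {"_value": [""], "_idx": [0]},
# }
# mutator = apply_fixed_architecture(model, fixed)
# # single choices are now replaced in place, so the model trains as a plain
# # nn.Module afterwards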
@@ -0,0 +1,206 @@ | |||||
""" | |||||
A deep MNIST classifier using convolutional layers. | |||||
This file is a modification of the official pytorch mnist example: | |||||
https://github.com/pytorch/examples/blob/master/mnist/main.py | |||||
""" | |||||
import os | |||||
import argparse | |||||
import logging | |||||
import sys | |||||
sys.path.append('..'+ '/' + '..') | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import torch.optim as optim | |||||
from torchvision import datasets, transforms | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
from mutator import ClassicMutator | |||||
import numpy as np | |||||
import time | |||||
import json | |||||
logger = logging.getLogger('mnist_AutoML') | |||||


class Net(nn.Module):
    def __init__(self, hidden_size):
        super(Net, self).__init__()
        # two options of conv1
        self.conv1 = LayerChoice(OrderedDict([
            ("conv5x5", nn.Conv2d(1, 20, 5, 1)),
            ("conv3x3", nn.Conv2d(1, 20, 3, 1))
        ]), key='first_conv')
        # two options of mid_conv
        self.mid_conv = LayerChoice([
            nn.Conv2d(20, 20, 3, 1, padding=1),
            nn.Conv2d(20, 20, 5, 1, padding=2)
        ], key='mid_conv')
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)
        # skip connection over mid_conv
        self.input_switch = InputChoice(n_candidates=2,
                                        n_chosen=1,
                                        key='skip')

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        old_x = x
        x = F.relu(self.mid_conv(x))
        zero_x = torch.zeros_like(old_x)
        skip_x = self.input_switch([zero_x, old_x])
        x = torch.add(x, skip_x)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
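
# --- Illustrative addition (not part of the original diff): once a mutator is
# applied, a MNIST-shaped input should yield 10 log-probabilities, e.g.:
#
# net = Net(hidden_size=512)
# ClassicMutator(net, trial_id=0, selected_path='./selected_space.json',
#                search_space_path='./search_space.json')
# assert net(torch.randn(1, 1, 28, 28)).shape == (1, 10)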

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))
    return accuracy

def main(args):
    global_result = {'accuracy': []}
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()
    torch.manual_seed(args['seed'])
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    data_dir = args['data_dir']
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args['batch_size'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True, **kwargs)

    hidden_size = args['hidden_size']
    model = Net(hidden_size=hidden_size).to(device)
    ClassicMutator(model, trial_id=args['trial_id'],
                   selected_path=args["selected_space_path"],
                   search_space_path=args["search_space_path"])
    optimizer = optim.SGD(model.parameters(), lr=args['lr'],
                          momentum=args['momentum'])

    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_acc = test(args, model, device, test_loader)
        print({"type": "accuracy", "result": {"sequence": epoch, "category": "epoch", "value": test_acc}})
        global_result['accuracy'].append(test_acc)
    return global_result

def dump_global_result(args, global_result):
    with open(args['result_path'], "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=True, indent=2)

def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--data_dir", type=str,
                        default='./data', help="data directory")
    parser.add_argument("--selected_space_path", type=str,
                        default='./selected_space.json', help="selected space path")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search space path")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="result path")
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
                        help='hidden layer size (default: 512)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
    args, _ = parser.parse_known_args()
    return args

if __name__ == '__main__':
    try:
        start = time.time()
        params = vars(get_params())
        global_result = main(params)
        global_result['cost_time'] = str(time.time() - start) + 's'
        dump_global_result(params, global_result)
    except Exception as exception:
        logger.exception(exception)
        raise
@@ -0,0 +1,52 @@ | |||||
import os | |||||
import argparse | |||||
import logging | |||||
import sys | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import torch.optim as optim | |||||
from torchvision import datasets, transforms | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
from mutator import ClassicMutator | |||||
import numpy as np | |||||
class Net(nn.Module): | |||||
def __init__(self, hidden_size): | |||||
super(Net, self).__init__() | |||||
# two options of conv1 | |||||
self.conv1 = LayerChoice(OrderedDict([ | |||||
("conv5x5", nn.Conv2d(1, 20, 5, 1)), | |||||
("conv3x3", nn.Conv2d(1, 20, 3, 1)) | |||||
]), key='conv1') | |||||
# two options of mid_conv | |||||
self.mid_conv = LayerChoice(OrderedDict([ | |||||
("conv3x3",nn.Conv2d(20, 20, 3, 1, padding=1)), | |||||
("conv5x5",nn.Conv2d(20, 20, 5, 1, padding=2)) | |||||
]), key='mid_conv') | |||||
self.conv2 = nn.Conv2d(20, 50, 5, 1) | |||||
self.fc1 = nn.Linear(4*4*50, hidden_size) | |||||
self.fc2 = nn.Linear(hidden_size, 10) | |||||
# skip connection over mid_conv | |||||
self.input_switch = InputChoice(n_candidates=2, | |||||
n_chosen=1, | |||||
key='skip') | |||||
def forward(self, x): | |||||
x = F.relu(self.conv1(x)) | |||||
x = F.max_pool2d(x, 2, 2) | |||||
old_x = x | |||||
x = F.relu(self.mid_conv(x)) | |||||
zero_x = torch.zeros_like(old_x) | |||||
skip_x = self.input_switch([zero_x, old_x]) | |||||
x = torch.add(x, skip_x) | |||||
x = F.relu(self.conv2(x)) | |||||
x = F.max_pool2d(x, 2, 2) | |||||
x = x.view(-1, 4*4*50) | |||||
x = F.relu(self.fc1(x)) | |||||
x = self.fc2(x) | |||||
return F.log_softmax(x, dim=1) |
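
# --- Illustrative addition (not part of the original diff): a hedged smoke
# test, assuming a selected-space JSON with keys conv1/mid_conv/skip exists.
#
# if __name__ == "__main__":
#     from fixed import apply_fixed_architecture
#     net = Net(hidden_size=512)
#     apply_fixed_architecture(net, "./selected_space.json")
#     print(net(torch.randn(2, 1, 28, 28)).shape)  # expected: torch.Size([2, 10])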
@@ -0,0 +1,260 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import json | |||||
import logging | |||||
import os | |||||
import sys | |||||
sys.path.append('..'+ '/' + '..') | |||||
import torch | |||||
from pytorch.mutables import LayerChoice, InputChoice, MutableScope | |||||
from pytorch.mutator import Mutator | |||||
import numpy as np | |||||
import random | |||||
logger = logging.getLogger(__name__) | |||||
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE" | |||||
NNI_PLATFORM = "GPU" | |||||
LAYER_CHOICE = "layer_choice" | |||||
INPUT_CHOICE = "input_choice" | |||||


def get_and_apply_next_architecture(model):
    """
    Wrapper of :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator` to make it more meaningful,
    similar to ``get_next_parameter`` for HPO.

    It will generate a search space based on ``model``.
    If the env variable ``NNI_GEN_SEARCH_SPACE`` exists, this is a dry-run mode for
    generating the search space of the experiment.
    If not, there are two modes: one is NNI experiment mode, where users
    use ``nnictl`` to start an experiment; the other is standalone mode,
    where users directly run the trial command. The latter chooses the first
    one(s) for each LayerChoice and InputChoice.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """
    ClassicMutator(model)


class ClassicMutator(Mutator):
    """
    This mutator applies the architecture chosen by the tuner.
    It implements the forward function of LayerChoice and InputChoice,
    to only activate the chosen ones.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    trial_id : int
        Seed used when randomly generating the chosen architecture.
    selected_path : str
        Where the chosen architecture is dumped.
    search_space_path : str
        Where the generated search space is dumped.
    load_selected_space : bool
        If true, load the chosen architecture from ``selected_path`` instead of sampling one.
    """

    def __init__(self, model, trial_id=0, selected_path='./selected_space.json',
                 search_space_path='./search_space.json', load_selected_space=False):
        super(ClassicMutator, self).__init__(model)
        self._chosen_arch = {}
        self._search_space = self._generate_search_space()
        self.trial_id = trial_id
        # dump the search space (originally a dry run guarded by NNI_GEN_SEARCH_SPACE)
        self._dump_search_space(search_space_path)
        if load_selected_space:
            logger.warning("load selected space.")
            with open(selected_path) as fp:
                self._chosen_arch = json.load(fp)
        else:
            # randomly generate a chosen arch (in place of one from a tuner)
            self._chosen_arch = self.random_generate_chosen()
            self._generate_selected_space(selected_path)
        self.reset()
    def _sample_layer_choice(self, mutable, idx, value, search_space_item):
        """
        Convert layer choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            Index of the chosen candidate in the list.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list for the corresponding search space.
        """
        # doesn't support multihot for layer choice yet
        onehot_list = [False] * len(mutable)
        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
            "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
        onehot_list[idx] = True
        return torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable
    def _sample_input_choice(self, mutable, idx, value, search_space_item):
        """
        Convert input choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : list of int
            Indices of the chosen candidates.
        value : list of str
            The verbose representation of the selected values.
        search_space_item : dict
            The dict for the corresponding search space.
        """
        candidate_repr = search_space_item["candidates"]
        multihot_list = [False] * mutable.n_candidates
        for i, v in zip(idx, value):
            assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
                "Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
            assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
            multihot_list[i] = True
        return torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable
    def sample_search(self):
        """
        See :meth:`sample_final`.
        """
        return self.sample_final()

    def sample_final(self):
        """
        Convert the chosen arch and apply it on model.
        """
        assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
            "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                       self._chosen_arch.keys())
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, (LayerChoice, InputChoice)):
                assert mutable.key in self._chosen_arch, \
                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
                data = self._chosen_arch[mutable.key]
                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
                    "'{}' is not a valid choice.".format(data)
                if isinstance(mutable, LayerChoice):
                    result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
                                                                    self._search_space[mutable.key]["_value"])
                elif isinstance(mutable, InputChoice):
                    result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
                                                                    self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return result
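
    # Illustrative note (not part of the original diff): a LayerChoice with
    # candidates ["conv5x5", "conv3x3"] and chosen arch
    # {"_value": "conv5x5", "_idx": 0} is converted above to the one-hot mask
    # tensor([True, False]); an InputChoice produces a multi-hot mask instead.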
    def _standalone_generate_chosen(self):
        """
        Generate the chosen architecture for standalone mode,
        i.e., choose the first one(s) for LayerChoice and InputChoice.
        ::
            { key_name: {"_value": "conv1",
                         "_idx": 0} }
            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }

        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                chosen_arch[key] = {"_value": choices[0], "_idx": 0}
            elif val["_type"] == INPUT_CHOICE:
                choices = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                if n_chosen is None:
                    n_chosen = len(choices)
                chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen_arch
    def random_generate_chosen(self):
        """
        Randomly generate a chosen architecture, seeded with ``trial_id`` so that
        each trial samples a reproducible architecture.
        ::
            { key_name: {"_value": "conv1",
                         "_idx": 0} }
            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }

        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        np.random.seed(self.trial_id)
        random.seed(self.trial_id)
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                # cast to int so json.dump can serialize it
                chosen_idx = int(np.random.randint(len(choices)))
                chosen_arch[key] = {"_value": choices[chosen_idx], "_idx": chosen_idx}
            elif val["_type"] == INPUT_CHOICE:
                choices = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                if n_chosen is None:
                    n_chosen = len(choices)
                # sample n_chosen candidates out of all of them
                chosen_idx = random.sample(range(len(choices)), n_chosen)
                chosen_arch[key] = {"_value": [choices[idx] for idx in chosen_idx], "_idx": chosen_idx}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen_arch
    def _generate_search_space(self):
        """
        Generate search space from mutables.
        Here is the search space format:
        ::
            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }
            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }

        Returns
        -------
        dict
            the generated search space
        """
        search_space = {}
        for mutable in self.mutables:
            # for now we only generate flattened search space
            if isinstance(mutable, LayerChoice):
                key = mutable.key
                val = mutable.names
                search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
            elif isinstance(mutable, InputChoice):
                key = mutable.key
                search_space[key] = {"_type": INPUT_CHOICE,
                                     "_value": {"candidates": mutable.choose_from,
                                                "n_chosen": mutable.n_chosen}}
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return search_space
def _generate_selected_space(self,file_path): | |||||
with open(file_path, "w") as ss_file: | |||||
json.dump(self._chosen_arch, ss_file, sort_keys=True, indent=2) |
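
# --- Illustrative addition (not part of the original diff): a hedged smoke
# test, assuming model.py from this diff sits alongside this file.
#
# if __name__ == "__main__":
#     from model import Net
#     net = Net(hidden_size=512)
#     ClassicMutator(net, trial_id=0,
#                    selected_path="./selected_space.json",
#                    search_space_path="./search_space.json")
#     # search_space.json and selected_space.json are written as a side effect;
#     # the sampled candidates are now active in net's forward pass.
#     print(net(torch.randn(1, 1, 28, 28)).shape)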
@@ -0,0 +1,64 @@
stage1: python trainer.py --trial_id=1 --model_selected_space_path='./exp/train/2/model_selected_space.json' --search_space_path='./search_space.json' --result_path='./exp/train/2/result.json'
stage2: python selector.py --experiment_dir='./exp' --best_selected_space_path='./best_selected_space.json'
stage3: python retrainer.py --best_checkpoint_dir='experiment_id/' --best_selected_space_path='./best_selected_space.json' --result_path='result.json'
search_space.json:
{
  "first_conv": {
    "_type": "layer_choice",
    "_value": [
      "conv5x5",
      "conv3x3"
    ]
  },
  "mid_conv": {
    "_type": "layer_choice",
    "_value": [
      "0",
      "1"
    ]
  },
  "skip": {
    "_type": "input_choice",
    "_value": {
      "candidates": [
        "",
        ""
      ],
      "n_chosen": 1
    }
  }
}

selected_space.json:
{
  "first_conv": {
    "_idx": 0,
    "_value": "conv5x5"
  },
  "mid_conv": {
    "_idx": 0,
    "_value": "0"
  },
  "skip": {
    "_idx": [
      0
    ],
    "_value": [
      ""
    ]
  }
}

result.json:
{'type': 'accuracy', 'result': {'sequence': 1, 'category': 'epoch', 'value': 96.73815907059875}}
{'type': 'accuracy', 'result': {'sequence': 2, 'category': 'epoch', 'value': 97.6988382484361}}
{'type': 'accuracy', 'result': {'sequence': 3, 'category': 'epoch', 'value': 98.63717605004469}}
{'type': 'accuracy', 'result': {'sequence': 4, 'category': 'epoch', 'value': 98.72654155495978}}
{'type': 'accuracy', 'result': {'sequence': 5, 'category': 'epoch', 'value': 99.27390527256479}}
{'type': 'accuracy', 'result': {'sequence': 6, 'category': 'epoch', 'value': 99.13985701519213}}
{'type': 'accuracy', 'result': {'sequence': 7, 'category': 'epoch', 'value': 99.3632707774799}}
{'type': 'accuracy', 'result': {'sequence': 8, 'category': 'epoch', 'value': 99.4414655942806}}
{'type': 'accuracy', 'result': {'sequence': 9, 'category': 'epoch', 'value': 99.67605004468275}}
{'type': 'accuracy', 'result': {'sequence': 10, 'category': 'epoch', 'value': 99.74307417336908}}
@@ -0,0 +1,288 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import copy | |||||
import logging | |||||
import os | |||||
import argparse | |||||
import logging | |||||
import sys | |||||
sys.path.append('..'+ '/' + '..') | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from torchvision import datasets, transforms | |||||
from model import Net | |||||
from pytorch.trainer import Trainer | |||||
from pytorch.utils import AverageMeterGroup | |||||
from pytorch.utils import mkdirs | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
from fixed import apply_fixed_architecture | |||||
from mutator import ClassicMutator | |||||
from abc import ABC, abstractmethod | |||||
from pytorch.retrainer import Retrainer | |||||
import numpy as np | |||||
import time | |||||
import json | |||||
logger = logging.getLogger(__name__) | |||||
#logger.setLevel(logging.INFO) | |||||


class ClassicnasRetrainer(Retrainer):
    """
    Classic NAS retrainer: retrains the best selected architecture from scratch.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, returns a loss tensor.
    metrics : callable
        Receives logits and ground truth label, returns a dict of metrics.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : ClassicMutator
        Use in case of customizing your own ClassicMutator. By default will instantiate a ClassicMutator.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        List of callbacks to trigger at events.
    arc_learning_rate : float
        Learning rate of architecture parameters.
    unrolled : bool
        ``True`` if using second order optimization, else first order optimization.
    """
    def __init__(self, model, loss, metrics,
                 optimizer, epochs, dataset_train, dataset_valid, search_space_path, selected_space_path,
                 checkpoint_dir, trial_id,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.epochs = epochs
        self.device = device
        self.batch_size = batch_size
        self.checkpoint_dir = checkpoint_dir
        # NOTE: ``kwargs`` is taken from module scope (defined under __main__ below).
        self.train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(dataset_train, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=batch_size, shuffle=True, **kwargs)
        self.test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(dataset_valid, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=1000, shuffle=True, **kwargs)
        self.search_space_path = search_space_path
        self.selected_space_path = selected_space_path
        self.trial_id = trial_id
        self.result = {"accuracy": [], "cost_time": 0.}
    def train(self):
        # phase 1: fix the architecture chosen by the selector
        apply_fixed_architecture(self.model, self.selected_space_path)
        # phase 2: train the child network from scratch
        for child_epoch in range(1, self.epochs + 1):
            self.model.train()
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                self.optimizer.step()
                # NOTE: ``args`` is taken from module scope (defined under __main__ below).
                if batch_idx % args['log_interval'] == 0:
                    logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        child_epoch, batch_idx * len(data), len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader), loss.item()))
            test_acc = self.validate()
            print({"type": "accuracy", "result": {"sequence": child_epoch, "category": "epoch", "value": test_acc}})
            with open(args['result_path'], "a") as ss_file:
                ss_file.write(json.dumps(
                    {"type": "accuracy", "result": {"sequence": child_epoch, "category": "epoch", "value": test_acc}}) + '\n')
            self.result['accuracy'].append(test_acc)
    def validate(self):
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                # sum up batch loss
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # get the index of the max log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)
        accuracy = 100. * correct / len(self.test_loader.dataset)
        logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(self.test_loader.dataset), accuracy))
        return accuracy
    def checkpoint(self):
        """
        Override to dump a checkpoint.
        """
        if isinstance(self.model, nn.DataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        dest_path = os.path.join(self.checkpoint_dir, f"best_checkpoint_epoch{self.epochs}.pth")
        logger.info("Saving model to %s", dest_path)
        torch.save(state_dict, dest_path)
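
# --- Illustrative addition (not part of the original diff): restoring the
# checkpoint dumped above, e.g.:
#   state_dict = torch.load(os.path.join(checkpoint_dir,
#                                        f"best_checkpoint_epoch{epochs}.pth"))
#   model.load_state_dict(state_dict)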

def dump_global_result(args, global_result):
    with open(args['result_path'], "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=True, indent=2)

def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--data_dir", type=str,
                        default='./data', help="data directory")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./selected_space.json', help="best selected space path")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="search space path")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="result path")
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
                        help='hidden layer size (default: 512)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument("--best_checkpoint_dir", type=str, default="path/to/",
                        help="Path for saved checkpoints. (default: %(default)s)")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
    args, _ = parser.parse_known_args()
    return args

if __name__ == '__main__':
    try:
        start = time.time()
        params = vars(get_params())
        args = params
        use_cuda = not args['no_cuda'] and torch.cuda.is_available()
        torch.manual_seed(args['seed'])
        device = torch.device("cuda" if use_cuda else "cpu")
        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

        data_dir = args['data_dir']
        hidden_size = args['hidden_size']
        model = Net(hidden_size=hidden_size).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'],
                                    momentum=args['momentum'])
        # mkdirs(args['search_space_path'])
        mkdirs(args['best_selected_space_path'])
        mkdirs(args['result_path'])
        trainer = ClassicnasRetrainer(model,
                                      loss=None,
                                      metrics=None,
                                      optimizer=optimizer,
                                      epochs=args['epochs'],
                                      dataset_train=data_dir,
                                      dataset_valid=data_dir,
                                      search_space_path=args['search_space_path'],
                                      selected_space_path=args['best_selected_space_path'],
                                      checkpoint_dir=args['best_checkpoint_dir'],
                                      trial_id=args['trial_id'],
                                      batch_size=args['batch_size'],
                                      log_frequency=args['log_interval'],
                                      device=device,
                                      unrolled=None,
                                      callbacks=None)
        with open(args['result_path'], "w") as ss_file:
            ss_file.write('')
        trainer.train()
        trainer.checkpoint()
        global_result = trainer.result
        global_result['cost_time'] = str(time.time() - start) + 's'
        # dump_global_result(params, global_result)
    except Exception as exception:
        logger.exception(exception)
        raise
@@ -0,0 +1,58 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
from pytorch.selector import Selector | |||||
from pytorch.utils import mkdirs | |||||
import shutil | |||||
import argparse | |||||
import os | |||||
import json | |||||


class ClassicnasSelector(Selector):
    def __init__(self, args, single_candidate=True):
        super().__init__(single_candidate)
        self.args = args

    def fit(self):
        """
        Scan the result file of every trial under ``<experiment_dir>/train`` and
        copy the selected space of the best-performing trial (by maximum accuracy)
        to ``best_selected_space_path``.
        """
        train_dir = os.path.join(self.args['experiment_dir'], 'train')
        max_accuracy = 0
        best_selected_space = ''
        for trial_id in os.listdir(train_dir):
            path = os.path.join(train_dir, trial_id, 'result', 'result.json')
            max_accuracy_trial = 0
            with open(path, 'r') as f:
                for line in f:
                    result_dict = json.loads(line)
                    accuracy = result_dict["result"]["value"]
                    if accuracy > max_accuracy_trial:
                        max_accuracy_trial = accuracy
            print(max_accuracy_trial)
            if max_accuracy_trial > max_accuracy:
                max_accuracy = max_accuracy_trial
                best_selected_space = os.path.join(train_dir, trial_id, 'model_selected_space',
                                                   'model_selected_space.json')
                print('best trial id:', trial_id)
        shutil.copyfile(best_selected_space, self.args['best_selected_space_path'])

def get_params():
    # Selector settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--experiment_dir", type=str,
                        default='./experiment_dir', help="experiment directory")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="best selected space path")
    args, _ = parser.parse_known_args()
    return args


if __name__ == "__main__":
    params = vars(get_params())
    args = params
    mkdirs(args['best_selected_space_path'])
    hpo_selector = ClassicnasSelector(args, single_candidate=False)
    hpo_selector.fit()
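
# --- Illustrative addition (not part of the original diff): the directory
# layout fit() assumes, inferred from the paths above:
#
#   <experiment_dir>/
#     train/
#       <trial_id>/
#         result/result.json                              # one JSON record per line
#         model_selected_space/model_selected_space.json  # exported architecture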
@@ -0,0 +1,269 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import copy | |||||
import logging | |||||
import os | |||||
import argparse | |||||
import logging | |||||
import sys | |||||
sys.path.append('..'+ '/' + '..') | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from torchvision import datasets, transforms | |||||
from model import Net | |||||
from pytorch.trainer import Trainer | |||||
from pytorch.utils import AverageMeterGroup | |||||
from pytorch.utils import mkdirs | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
from mutator import ClassicMutator | |||||
import numpy as np | |||||
import time | |||||
import json | |||||
logger = logging.getLogger(__name__) | |||||
#logger.setLevel(logging.INFO) | |||||


class ClassicnasTrainer(Trainer):
    """
    Classic NAS trainer: samples one random architecture per search epoch and
    trains it for ``epochs`` child epochs.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, returns a loss tensor.
    metrics : callable
        Receives logits and ground truth label, returns a dict of metrics.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    epochs : int
        Number of child epochs planned for training each sampled architecture.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : ClassicMutator
        Use in case of customizing your own ClassicMutator. By default will instantiate a ClassicMutator.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        List of callbacks to trigger at events.
    arc_learning_rate : float
        Learning rate of architecture parameters.
    unrolled : bool
        ``True`` if using second order optimization, else first order optimization.
    """
    def __init__(self, model, loss, metrics,
                 optimizer, epochs, dataset_train, dataset_valid, search_space_path, selected_space_path, trial_id,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.epochs = epochs
        self.device = device
        self.batch_size = batch_size
        # NOTE: ``kwargs`` is taken from module scope (defined under __main__ below).
        self.train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(dataset_train, train=True, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=batch_size, shuffle=True, **kwargs)
        self.test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(dataset_valid, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=1000, shuffle=True, **kwargs)
        self.search_space_path = search_space_path
        self.selected_space_path = selected_space_path
        self.trial_id = trial_id
        # number of architectures to sample during the search
        self.num_epochs = 10
        self.classicmutator = ClassicMutator(self.model, trial_id=self.trial_id,
                                             selected_path=self.selected_space_path,
                                             search_space_path=self.search_space_path)
        self.result = {"accuracy": [], "cost_time": 0.}
    def train_one_epoch(self, epoch):
        # phase 1: architecture step, re-sample a random architecture seeded by the epoch
        self.classicmutator.trial_id = epoch
        self.classicmutator._chosen_arch = self.classicmutator.random_generate_chosen()
        # apply the newly sampled architecture
        self.classicmutator.reset()
        # phase 2: child network step
        for child_epoch in range(1, self.epochs + 1):
            self.model.train()
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                self.optimizer.step()
                # NOTE: ``args`` is taken from module scope (defined under __main__ below).
                if batch_idx % args['log_interval'] == 0:
                    logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        child_epoch, batch_idx * len(data), len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader), loss.item()))
            test_acc = self.validate_one_epoch(epoch)
            print({"type": "accuracy", "result": {"sequence": child_epoch, "category": "epoch", "value": test_acc}})
            with open(args['result_path'], "a") as ss_file:
                ss_file.write(json.dumps(
                    {"type": "accuracy", "result": {"sequence": child_epoch, "category": "epoch", "value": test_acc}}) + '\n')
            self.result['accuracy'].append(test_acc)
    def validate_one_epoch(self, epoch):
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                # sum up batch loss
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # get the index of the max log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)
        accuracy = 100. * correct / len(self.test_loader.dataset)
        logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(self.test_loader.dataset), accuracy))
        return accuracy
def train(self): | |||||
""" | |||||
Train ``num_epochs``. | |||||
Trigger callbacks at the start and the end of each epoch. | |||||
Parameters | |||||
---------- | |||||
validate : bool | |||||
If ``true``, will do validation every epoch. | |||||
""" | |||||
for epoch in range(self.num_epochs): | |||||
# training | |||||
self.train_one_epoch(epoch) | |||||
def dump_global_result(args, global_result):
with open(args['result_path'], "w") as ss_file: | |||||
json.dump(global_result, ss_file, sort_keys=True, indent=2) | |||||
def get_params(): | |||||
# Training settings | |||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') | |||||
parser.add_argument("--data_dir", type=str, | |||||
default='./data', help="data directory") | |||||
parser.add_argument("--model_selected_space_path", type=str, | |||||
default='./selected_space.json', help="selected_space_path") | |||||
parser.add_argument("--search_space_path", type=str, | |||||
default='./selected_space.json', help="search_space_path") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='./model_result.json', help="result_path") | |||||
parser.add_argument('--batch_size', type=int, default=64, metavar='N', | |||||
help='input batch size for training (default: 64)') | |||||
parser.add_argument("--hidden_size", type=int, default=512, metavar='N', | |||||
help='hidden layer size (default: 512)') | |||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', | |||||
help='learning rate (default: 0.01)') | |||||
parser.add_argument('--momentum', type=float, default=0.5, metavar='M', | |||||
help='SGD momentum (default: 0.5)') | |||||
parser.add_argument('--epochs', type=int, default=10, metavar='N', | |||||
help='number of epochs to train (default: 10)') | |||||
parser.add_argument('--seed', type=int, default=1, metavar='S', | |||||
help='random seed (default: 1)') | |||||
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
parser.add_argument('--log_interval', type=int, default=1000, metavar='N', | |||||
help='how many batches to wait before logging training status') | |||||
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
args, _ = parser.parse_known_args() | |||||
return args | |||||
if __name__ == '__main__': | |||||
try: | |||||
        start = time.time()
        params = vars(get_params())
        args = params
use_cuda = not args['no_cuda'] and torch.cuda.is_available() | |||||
torch.manual_seed(args['seed']) | |||||
device = torch.device("cuda" if use_cuda else "cpu") | |||||
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} | |||||
data_dir = args['data_dir'] | |||||
hidden_size = args['hidden_size'] | |||||
model = Net(hidden_size=hidden_size).to(device) | |||||
optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'], | |||||
momentum=args['momentum']) | |||||
mkdirs(args['search_space_path']) | |||||
mkdirs(args['model_selected_space_path']) | |||||
mkdirs(args['result_path']) | |||||
trainer = ClassicnasTrainer(model, | |||||
loss=None, | |||||
metrics=None, | |||||
optimizer=optimizer, | |||||
epochs=args['epochs'], | |||||
dataset_train=data_dir, | |||||
dataset_valid=data_dir, | |||||
search_space_path = args['search_space_path'], | |||||
selected_space_path = args['model_selected_space_path'], | |||||
trial_id = args['trial_id'], | |||||
batch_size=args['batch_size'], | |||||
log_frequency=args['log_interval'], | |||||
device= device, | |||||
unrolled=None, | |||||
callbacks=None) | |||||
with open(args['result_path'], "w") as ss_file: | |||||
ss_file.write('') | |||||
trainer.train_one_epoch(args['trial_id']) | |||||
#trainer.train() | |||||
global_result = trainer.result | |||||
#global_result['cost_time'] = str(time.time() - start) +'s' | |||||
#dump_global_result(params,global_result) | |||||
except Exception as exception: | |||||
logger.exception(exception) | |||||
raise |
@@ -0,0 +1,70 @@ | |||||
# Cream of the Crop: Distilling Prioritized Paths For One-Shot Neural Architecture Search | |||||
## 0x01 Requirements
* Install the following requirements: | |||||
``` | |||||
future | |||||
thop | |||||
timm<0.4 | |||||
yacs | |||||
ptflops==0.6.4 | |||||
#tensorboardx | |||||
#tensorboard | |||||
#opencv-python | |||||
#torch-scope | |||||
#git+https://github.com/sovrasov/flops-counter.pytorch.git | |||||
#git+https://github.com/Tramac/torchscope.git | |||||
``` | |||||
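Assuming the list above is saved as `requirements.txt` (the file name here is an assumption), it can be installed with:
```
pip install -r requirements.txt
```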
* (Required) Build and install apex to accelerate training
(see [yuque](https://www.yuque.com/kcgyxv/ukpea3/mxz5xy));
it is slightly faster than PyTorch DistributedDataParallel.
* Put the ImageNet data in `./data` using the following script:
``` | |||||
cd TADL_DIR/pytorch/cream/ | |||||
ln -s /mnt/data . | |||||
``` | |||||
## 0x02 Quick Start | |||||
* Run the following script to search an architecture. | |||||
``` | |||||
python trainer.py | |||||
``` | |||||
* Selector (deprecated) | |||||
``` | |||||
python selector.py | |||||
``` | |||||
* Train the searched architectures.
> Note: the exponential moving average (model_ema) is not available yet.
``` | |||||
python retrainer.py | |||||
``` | |||||
<!-- | |||||
* Test trained models. | |||||
``` | |||||
$ cp configs/test.yaml.example configs/test.yaml | |||||
$ python -m torch.distributed.launch --nproc_per_node=1 ./test.py --cfg ./configs/test.yaml | |||||
> 01/26 02:06:27 AM | [Model-14] Flops: 13.768M Params: 2.673M | |||||
> 01/26 02:06:30 AM | Training on Process 0 with 1 GPUs. | |||||
> 01/26 02:06:30 AM | Restoring model state from checkpoint... | |||||
> 01/26 02:06:30 AM | Loaded checkpoint './pretrained/14.pth.tar' (epoch 591) | |||||
> 01/26 02:06:30 AM | Loaded state_dict_ema | |||||
> 01/26 02:06:32 AM | Test_EMA: [ 0/390] Time: 1.573 (1.573) Loss: 0.9613 (0.9613) Prec@1: 82.8125 (82.8125) Prec@5: 91.4062 (91.4062) | |||||
> ... | |||||
> 01/26 02:07:50 AM | Test_EMA: [ 390/390] Time: 0.077 (0.203) Loss: 3.4356 (2.0912) Prec@1: 25.0000 (53.7640) Prec@5: 53.7500 (77.2840) | |||||
``` | |||||
--> |
@@ -0,0 +1,5 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
from .trainer import CreamSupernetTrainer | |||||
from .mutator import RandomMutator |
@@ -0,0 +1,36 @@ | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from pytorch.mutator import Mutator | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
# TODO: This class is duplicate with SPOS. | |||||
class RandomMutator(Mutator): | |||||
""" | |||||
    Random mutator that samples a random candidate from the search space each time ``reset()`` is called.
    It uses PyTorch's random functions, so users can set the PyTorch seed to ensure deterministic behavior.
""" | |||||
def sample_search(self): | |||||
""" | |||||
Sample a random candidate. | |||||
""" | |||||
result = dict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
gen_index = torch.randint(high=len(mutable), size=(1, )) | |||||
result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool() | |||||
elif isinstance(mutable, InputChoice): | |||||
if mutable.n_chosen is None: | |||||
result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool() | |||||
else: | |||||
perm = torch.randperm(mutable.n_candidates) | |||||
mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)] | |||||
result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable | |||||
return result | |||||
def sample_final(self): | |||||
""" | |||||
Same as :meth:`sample_search`. | |||||
""" | |||||
return self.sample_search() |
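# A minimal usage sketch (assuming the base ``Mutator`` exposes ``reset()`` as in NNI,
# which calls ``sample_search()`` and caches the result on the model's mutables):
#
#     mutator = RandomMutator(model)   # ``model`` is any network built with LayerChoice/InputChoice
#     mutator.reset()                  # samples a fresh random architecture for the next forward pass
#     arch = mutator.sample_final()    # dict mapping mutable keys to boolean selection masks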
@@ -0,0 +1,437 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import os | |||||
import json | |||||
import numpy as np | |||||
import torch | |||||
import logging | |||||
from copy import deepcopy | |||||
from pytorch.trainer import Trainer | |||||
from pytorch.utils import AverageMeterGroup | |||||
from .utils import accuracy, reduce_metrics | |||||
logger = logging.getLogger(__name__) | |||||
class CreamSupernetTrainer(Trainer): | |||||
""" | |||||
    This trainer trains a supernet and outputs prioritized architectures that can be used for other tasks.
Parameters | |||||
---------- | |||||
model : nn.Module | |||||
Model with mutables. | |||||
loss : callable | |||||
Called with logits and targets. Returns a loss tensor. | |||||
val_loss : callable | |||||
Called with logits and targets for validation only. Returns a loss tensor. | |||||
optimizer : Optimizer | |||||
Optimizer that optimizes the model. | |||||
num_epochs : int | |||||
Number of epochs of training. | |||||
    train_loader : iterable
        Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
mutator : Mutator | |||||
A mutator object that has been initialized with the model. | |||||
batch_size : int | |||||
Batch size. | |||||
log_frequency : int | |||||
Number of mini-batches to log metrics. | |||||
    meta_sta_epoch : int
        Epoch from which the meta matching network starts picking the teacher architecture.
    update_iter : int
        Interval (in steps) between updates of the meta matching network.
    slices : int
        Mini-batch size of the data slices used when training the meta matching network.
    pool_size : int
        Size of the prioritized board (number of prioritized paths kept).
    pick_method : str
        How to pick the teacher network ('top1' or 'meta').
    choice_num : int
        Number of candidate operations per layer in the supernet.
    sta_num : tuple of int
        Number of layers in each stage of the supernet (5 stages in total).
    acc_gap : int
        Accuracy margin over the worst board entry beyond which the FLOPs constraint is ignored.
    flops_dict : dict
        Dictionary with the FLOPs of each layer's candidate operations in the supernet.
    flops_fixed : int
        FLOPs of the fixed part of the supernet.
    local_rank : int
        Index of the current rank.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
""" | |||||
def __init__(self, selected_space, model, loss, val_loss, | |||||
optimizer, num_epochs, train_loader, valid_loader, | |||||
mutator=None, batch_size=64, log_frequency=None, | |||||
meta_sta_epoch=20, update_iter=200, slices=2, | |||||
pool_size=10, pick_method='meta', choice_num=6, | |||||
sta_num=(4, 4, 4, 4, 4), acc_gap=5, | |||||
flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None, result_path=None): | |||||
assert torch.cuda.is_available() | |||||
super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None, | |||||
optimizer, num_epochs, None, None, | |||||
batch_size, None, None, log_frequency, callbacks) | |||||
self.selected_space = selected_space | |||||
        self.model = model
        self.loss = loss
        self.val_loss = val_loss
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.log_frequency = log_frequency
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.num_epochs = num_epochs
self.meta_sta_epoch = meta_sta_epoch | |||||
self.update_iter = update_iter | |||||
self.slices = slices | |||||
self.pick_method = pick_method | |||||
self.pool_size = pool_size | |||||
self.local_rank = local_rank | |||||
self.choice_num = choice_num | |||||
self.sta_num = sta_num | |||||
self.acc_gap = acc_gap | |||||
self.flops_dict = flops_dict | |||||
self.flops_fixed = flops_fixed | |||||
self.current_student_arch = None | |||||
self.current_teacher_arch = None | |||||
self.main_proc = (local_rank == 0) | |||||
self.current_epoch = 0 | |||||
self.prioritized_board = [] | |||||
self.result_path = result_path | |||||
# size of prioritized board | |||||
def _board_size(self): | |||||
return len(self.prioritized_board) | |||||
# select teacher architecture according to the logit difference | |||||
def _select_teacher(self): | |||||
self._replace_mutator_cand(self.current_student_arch) | |||||
if self.pick_method == 'top1': | |||||
meta_value, teacher_cand = 0.5, sorted( | |||||
self.prioritized_board, reverse=True)[0][3] | |||||
elif self.pick_method == 'meta': | |||||
meta_value, cand_idx, teacher_cand = -1000000000, -1, None | |||||
for now_idx, item in enumerate(self.prioritized_board): | |||||
inputx = item[4] | |||||
output = torch.nn.functional.softmax(self.model(inputx), dim=1) | |||||
weight = self.model.forward_meta(output - item[5]) | |||||
if weight > meta_value: | |||||
meta_value = weight | |||||
cand_idx = now_idx | |||||
teacher_cand = self.prioritized_board[cand_idx][3] | |||||
assert teacher_cand is not None | |||||
            meta_value = torch.sigmoid(-weight)
else: | |||||
            raise ValueError('Method not supported')
return meta_value, teacher_cand | |||||
# check whether to update prioritized board | |||||
def _isUpdateBoard(self, prec1, flops): | |||||
if self.current_epoch <= self.meta_sta_epoch: | |||||
return False | |||||
if len(self.prioritized_board) < self.pool_size: | |||||
return True | |||||
if prec1 > self.prioritized_board[-1][1] + self.acc_gap: | |||||
return True | |||||
if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]: | |||||
return True | |||||
return False | |||||
# update prioritized board | |||||
def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops): | |||||
if self._isUpdateBoard(prec1, flops): | |||||
val_prec1 = prec1 | |||||
training_data = deepcopy(inputs[:self.slices].detach()) | |||||
if len(self.prioritized_board) == 0: | |||||
features = deepcopy(outputs[:self.slices].detach()) | |||||
else: | |||||
features = deepcopy(teacher_output[:self.slices].detach()) | |||||
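            # board entries are tuples:
            # (val_prec1, prec1, flops, arch dict, cached input slice, cached soft probabilities)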
self.prioritized_board.append( | |||||
(val_prec1, | |||||
prec1, | |||||
flops, | |||||
self.current_student_arch, | |||||
training_data, | |||||
torch.nn.functional.softmax( | |||||
features, | |||||
dim=1))) | |||||
            self.prioritized_board = sorted(
                self.prioritized_board, reverse=True)
            # drop the worst entry once the board exceeds ``pool_size``
            if len(self.prioritized_board) > self.pool_size:
                del self.prioritized_board[-1]
# only update student network weights | |||||
def _update_student_weights_only(self, grad_1): | |||||
for weight, grad_item in zip( | |||||
self.model.module.rand_parameters(self.current_student_arch), grad_1): | |||||
weight.grad = grad_item | |||||
torch.nn.utils.clip_grad_norm_( | |||||
self.model.module.rand_parameters(self.current_student_arch), 1) | |||||
self.optimizer.step() | |||||
for weight, grad_item in zip( | |||||
self.model.module.rand_parameters(self.current_student_arch), grad_1): | |||||
del weight.grad | |||||
# only update meta networks weights | |||||
def _update_meta_weights_only(self, teacher_cand, grad_teacher): | |||||
for weight, grad_item in zip(self.model.module.rand_parameters( | |||||
teacher_cand, self.pick_method == 'meta'), grad_teacher): | |||||
weight.grad = grad_item | |||||
# clip gradients | |||||
torch.nn.utils.clip_grad_norm_( | |||||
self.model.module.rand_parameters( | |||||
self.current_student_arch, self.pick_method == 'meta'), 1) | |||||
self.optimizer.step() | |||||
for weight, grad_item in zip(self.model.module.rand_parameters( | |||||
teacher_cand, self.pick_method == 'meta'), grad_teacher): | |||||
del weight.grad | |||||
# simulate sgd updating | |||||
def _simulate_sgd_update(self, w, g, optimizer): | |||||
return g * optimizer.param_groups[-1]['lr'] + w | |||||
# split training images into several slices | |||||
    def _get_minibatch_input(self, input):
        # take the first ``slices`` samples as the meta-training mini-batch
        x = input[:self.slices].clone().detach()
        return x
# calculate 1st gradient of student architectures | |||||
def _calculate_1st_gradient(self, kd_loss): | |||||
self.optimizer.zero_grad() | |||||
grad = torch.autograd.grad( | |||||
kd_loss, | |||||
self.model.module.rand_parameters(self.current_student_arch), | |||||
create_graph=True) | |||||
return grad | |||||
# calculate 2nd gradient of meta networks | |||||
def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight): | |||||
self.optimizer.zero_grad() | |||||
grad_student_val = torch.autograd.grad( | |||||
validation_loss, | |||||
self.model.module.rand_parameters(self.current_student_arch), | |||||
retain_graph=True) | |||||
grad_teacher = torch.autograd.grad( | |||||
students_weight[0], | |||||
self.model.module.rand_parameters( | |||||
teacher_cand, | |||||
self.pick_method == 'meta'), | |||||
grad_outputs=grad_student_val) | |||||
return grad_teacher | |||||
# forward training data | |||||
def _forward_training(self, x, meta_value): | |||||
self._replace_mutator_cand(self.current_student_arch) | |||||
output = self.model(x) | |||||
with torch.no_grad(): | |||||
self._replace_mutator_cand(self.current_teacher_arch) | |||||
teacher_output = self.model(x) | |||||
soft_label = torch.nn.functional.softmax(teacher_output, dim=1) | |||||
kd_loss = meta_value * \ | |||||
self._cross_entropy_loss_with_soft_target(output, soft_label) | |||||
return kd_loss | |||||
# calculate soft target loss | |||||
def _cross_entropy_loss_with_soft_target(self, pred, soft_target): | |||||
        logsoftmax = torch.nn.LogSoftmax(dim=1)
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) | |||||
# forward validation data | |||||
    def _forward_validation(self, input, target):
        slice_size = self.slices
        x = input[slice_size:slice_size * 2].clone()
        self._replace_mutator_cand(self.current_student_arch)
        output_2 = self.model(x)
        validation_loss = self.loss(output_2, target[slice_size:slice_size * 2])
        return validation_loss
def _isUpdateMeta(self, batch_idx): | |||||
isUpdate = True | |||||
isUpdate &= (self.current_epoch > self.meta_sta_epoch) | |||||
isUpdate &= (batch_idx > 0) | |||||
isUpdate &= (batch_idx % self.update_iter == 0) | |||||
isUpdate &= (self._board_size() > 0) | |||||
return isUpdate | |||||
def _replace_mutator_cand(self, cand): | |||||
self.mutator._cache = cand | |||||
# update meta matching networks | |||||
def _run_update(self, input, target, batch_idx): | |||||
if self._isUpdateMeta(batch_idx): | |||||
x = self._get_minibatch_input(input) | |||||
meta_value, teacher_cand = self._select_teacher() | |||||
kd_loss = self._forward_training(x, meta_value) | |||||
# calculate 1st gradient | |||||
grad_1st = self._calculate_1st_gradient(kd_loss) | |||||
# simulate updated student weights | |||||
students_weight = [ | |||||
self._simulate_sgd_update( | |||||
p, grad_item, self.optimizer) for p, grad_item in zip( | |||||
self.model.module.rand_parameters(self.current_student_arch), grad_1st)] | |||||
# update student weights | |||||
self._update_student_weights_only(grad_1st) | |||||
validation_loss = self._forward_validation(input, target) | |||||
# calculate 2nd gradient | |||||
grad_teacher = self._calculate_2nd_gradient(validation_loss, | |||||
teacher_cand, | |||||
students_weight) | |||||
# update meta matching networks | |||||
self._update_meta_weights_only(teacher_cand, grad_teacher) | |||||
# delete internal variants | |||||
del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight | |||||
def _get_cand_flops(self, cand): | |||||
flops = 0 | |||||
for block_id, block in enumerate(cand): | |||||
            if block == 'LayerChoice1' or block == 'LayerChoice23':
continue | |||||
for idx, choice in enumerate(cand[block]): | |||||
flops += self.flops_dict[block_id][idx] * (1 if choice else 0) | |||||
return flops + self.flops_fixed | |||||
def train_one_epoch(self, epoch): | |||||
self.current_epoch = epoch | |||||
meters = AverageMeterGroup() | |||||
self.steps_per_epoch = len(self.train_loader) | |||||
for step, (input_data, target) in enumerate(self.train_loader): | |||||
self.mutator.reset() | |||||
self.current_student_arch = self.mutator._cache | |||||
input_data, target = input_data.cuda(), target.cuda() | |||||
# calculate flops of current architecture | |||||
cand_flops = self._get_cand_flops(self.mutator._cache) | |||||
# update meta matching network | |||||
self._run_update(input_data, target, step) | |||||
if self._board_size() > 0: | |||||
# select teacher architecture | |||||
meta_value, teacher_cand = self._select_teacher() | |||||
self.current_teacher_arch = teacher_cand | |||||
# forward supernet | |||||
if self._board_size() == 0 or epoch <= self.meta_sta_epoch: | |||||
self._replace_mutator_cand(self.current_student_arch) | |||||
output = self.model(input_data) | |||||
loss = self.loss(output, target) | |||||
kd_loss, teacher_output, teacher_cand = None, None, None | |||||
else: | |||||
self._replace_mutator_cand(self.current_student_arch) | |||||
output = self.model(input_data) | |||||
gt_loss = self.loss(output, target) | |||||
with torch.no_grad(): | |||||
self._replace_mutator_cand(self.current_teacher_arch) | |||||
teacher_output = self.model(input_data).detach() | |||||
soft_label = torch.nn.functional.softmax(teacher_output, dim=1) | |||||
kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label) | |||||
loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2 | |||||
# update network | |||||
self.optimizer.zero_grad() | |||||
loss.backward() | |||||
self.optimizer.step() | |||||
# update metrics | |||||
prec1, prec5 = accuracy(output, target, topk=(1, 5)) | |||||
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss} | |||||
metrics = reduce_metrics(metrics) | |||||
meters.update(metrics) | |||||
# update prioritized board | |||||
self._update_prioritized_board(input_data, | |||||
teacher_output, | |||||
output, | |||||
metrics['prec1'], | |||||
cand_flops) | |||||
if self.main_proc and ( | |||||
step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch): | |||||
logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs, | |||||
step + 1, len(self.train_loader), meters) | |||||
        # dump the architecture of the current best board entry as lists of chosen indices
        arch_list = []
        if len(self.prioritized_board) > 0:
            for arch in self.prioritized_board[0][3].values():
                one_hot = arch.numpy()
                arch_list.append(np.where(one_hot)[0].tolist())
        if len(arch_list) > 0:
            with open(self.selected_space, "w") as f:
                print("dump selected space.")
                json.dump({'selected_space': arch_list}, f)
def validate_one_epoch(self, epoch): | |||||
self.model.eval() | |||||
meters = AverageMeterGroup() | |||||
with torch.no_grad(): | |||||
            for step, (x, y) in enumerate(self.valid_loader):
                self.mutator.reset()
                # move data to the same device as the model
                x, y = x.cuda(), y.cuda()
                logits = self.model(x)
loss = self.val_loss(logits, y) | |||||
prec1, prec5 = accuracy(logits, y, topk=(1, 5)) | |||||
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss} | |||||
metrics = reduce_metrics(metrics) | |||||
meters.update(metrics) | |||||
if self.log_frequency is not None and step % self.log_frequency == 0: | |||||
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1, | |||||
self.num_epochs, step + 1, len(self.valid_loader), meters) | |||||
# print({'type': 'Accuracy', 'result': {'sequence': epoch, 'category': 'epoch', | |||||
# 'value': metrics["prec1"]}}) | |||||
if self.result_path is not None: | |||||
with open(self.result_path, "a") as ss_file: | |||||
ss_file.write(json.dumps( | |||||
{'type': 'Accuracy', | |||||
'result': {'sequence': epoch, | |||||
'category': 'epoch', | |||||
'value': metrics["prec1"]}}) + '\n') |
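# A minimal construction sketch (illustrative values; the loaders, loss functions and the
# FLOPs table are assumed to be built elsewhere, e.g. by the surrounding trainer script):
#
#     trainer = CreamSupernetTrainer(
#         'selected_space.json', model, loss_fn, val_loss_fn,
#         optimizer, num_epochs=120,
#         train_loader=train_loader, valid_loader=valid_loader,
#         mutator=RandomMutator(model),
#         flops_dict=flops_dict, flops_fixed=flops_fixed,
#         local_rank=0, result_path='result.json')
#     for epoch in range(120):
#         trainer.train_one_epoch(epoch)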
@@ -0,0 +1,39 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import os | |||||
import torch | |||||
import torch.distributed as dist | |||||
def accuracy(output, target, topk=(1,)): | |||||
""" Computes the precision@k for the specified values of k """ | |||||
maxk = max(topk) | |||||
batch_size = target.size(0) | |||||
_, pred = output.topk(maxk, 1, True, True) | |||||
pred = pred.t() | |||||
# one-hot case | |||||
if target.ndimension() > 1: | |||||
target = target.max(1)[1] | |||||
correct = pred.eq(target.reshape(1, -1).expand_as(pred)) | |||||
res = [] | |||||
for k in topk: | |||||
correct_k = correct[:k].reshape(-1).float().sum(0) | |||||
res.append(correct_k.mul_(1.0 / batch_size)) | |||||
return res | |||||
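# Example: ``output`` of shape (N, C) and integer ``target`` of shape (N,) give
# ``accuracy(output, target, topk=(1, 5)) == [top1_fraction, top5_fraction]`` (as 0-d tensors).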
def reduce_metrics(metrics): | |||||
return {k: reduce_tensor(v).item() for k, v in metrics.items()} | |||||
def reduce_tensor(tensor):
    # single-process fallback: the distributed all-reduce below is disabled,
    # so this simply collapses the (scalar) metric tensor with a sum
    rt = torch.sum(tensor)
# rt = tensor.clone() | |||||
# dist.all_reduce(rt, op=dist.ReduceOp.SUM) | |||||
# rt /= float(os.environ["WORLD_SIZE"]) | |||||
return rt |
@@ -0,0 +1,123 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
from __future__ import unicode_literals | |||||
from yacs.config import CfgNode as CN | |||||
DEFAULT_CROP_PCT = 0.875 | |||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |||||
__C = CN() | |||||
cfg = __C | |||||
__C.AUTO_RESUME = True | |||||
__C.DATA_DIR = './data/imagenet' | |||||
__C.MODEL = 'cream' | |||||
__C.RESUME_PATH = './experiments/ckps/resume.pth.tar' | |||||
__C.SAVE_PATH = './experiments/ckps/' | |||||
__C.SEED = 42 | |||||
__C.LOG_INTERVAL = 50 | |||||
__C.RECOVERY_INTERVAL = 0 | |||||
__C.WORKERS = 4 | |||||
__C.NUM_GPU = 1 | |||||
__C.SAVE_IMAGES = False | |||||
__C.AMP = False | |||||
__C.ACC_GAP = 5 | |||||
__C.OUTPUT = 'output/path/' | |||||
__C.EVAL_METRICS = 'prec1' | |||||
__C.TTA = 0 # Test or inference time augmentation | |||||
__C.LOCAL_RANK = 0 | |||||
__C.VERBOSE = False | |||||
# dataset configs | |||||
__C.DATASET = CN() | |||||
__C.DATASET.NUM_CLASSES = 1000 | |||||
__C.DATASET.IMAGE_SIZE = 224 # image patch size | |||||
__C.DATASET.INTERPOLATION = 'bilinear' # Image resize interpolation type | |||||
__C.DATASET.BATCH_SIZE = 32 # batch size | |||||
__C.DATASET.NO_PREFECHTER = False | |||||
__C.DATASET.PIN_MEM = True | |||||
__C.DATASET.VAL_BATCH_MUL = 4 | |||||
# model configs | |||||
__C.NET = CN() | |||||
__C.NET.SELECTION = 14 | |||||
__C.NET.GP = 'avg' # type of global pool ["avg", "max", "avgmax", "avgmaxc"] | |||||
__C.NET.DROPOUT_RATE = 0.0 # dropout rate | |||||
__C.NET.INPUT_ARCH = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] | |||||
# model ema parameters | |||||
__C.NET.EMA = CN() | |||||
__C.NET.EMA.USE = True | |||||
__C.NET.EMA.FORCE_CPU = False # force model ema to be tracked on CPU | |||||
__C.NET.EMA.DECAY = 0.9998 | |||||
# optimizer configs | |||||
__C.OPT = 'sgd' | |||||
__C.OPT_EPS = 1e-2 | |||||
__C.MOMENTUM = 0.9 | |||||
__C.WEIGHT_DECAY = 1e-4 | |||||
__C.OPTIMIZER = CN() | |||||
__C.OPTIMIZER.NAME = 'sgd' | |||||
__C.OPTIMIZER.MOMENTUM = 0.9 | |||||
__C.OPTIMIZER.WEIGHT_DECAY = 1e-3 | |||||
# scheduler configs | |||||
__C.SCHED = 'sgd' | |||||
__C.LR_NOISE = None | |||||
__C.LR_NOISE_PCT = 0.67 | |||||
__C.LR_NOISE_STD = 1.0 | |||||
__C.WARMUP_LR = 1e-4 | |||||
__C.MIN_LR = 1e-5 | |||||
__C.EPOCHS = 200 | |||||
__C.START_EPOCH = None | |||||
__C.DECAY_EPOCHS = 30.0 | |||||
__C.WARMUP_EPOCHS = 3 | |||||
__C.COOLDOWN_EPOCHS = 10 | |||||
__C.PATIENCE_EPOCHS = 10 | |||||
__C.DECAY_RATE = 0.1 | |||||
__C.LR = 1e-2 | |||||
__C.META_LR = 1e-4 | |||||
# data augmentation parameters | |||||
__C.AUGMENTATION = CN() | |||||
__C.AUGMENTATION.AA = 'rand-m9-mstd0.5' | |||||
__C.AUGMENTATION.COLOR_JITTER = 0.4 | |||||
__C.AUGMENTATION.RE_PROB = 0.2 # random erase prob | |||||
__C.AUGMENTATION.RE_MODE = 'pixel' # random erase mode | |||||
__C.AUGMENTATION.MIXUP = 0.0 # mixup alpha | |||||
__C.AUGMENTATION.MIXUP_OFF_EPOCH = 0 # turn off mixup after this epoch | |||||
__C.AUGMENTATION.SMOOTHING = 0.1 # label smoothing parameters | |||||
# batch norm parameters (only works with gen_efficientnet based models | |||||
# currently) | |||||
__C.BATCHNORM = CN() | |||||
__C.BATCHNORM.SYNC_BN = False | |||||
__C.BATCHNORM.BN_TF = False | |||||
__C.BATCHNORM.BN_MOMENTUM = 0.1 # batchnorm momentum override | |||||
__C.BATCHNORM.BN_EPS = 1e-5 # batchnorm eps override | |||||
# supernet training hyperparameters | |||||
__C.SUPERNET = CN() | |||||
__C.SUPERNET.UPDATE_ITER = 1300 | |||||
__C.SUPERNET.SLICE = 4 | |||||
__C.SUPERNET.POOL_SIZE = 10 | |||||
__C.SUPERNET.RESUNIT = False | |||||
__C.SUPERNET.DIL_CONV = False | |||||
__C.SUPERNET.UPDATE_2ND = True | |||||
__C.SUPERNET.FLOPS_MAXIMUM = 600 | |||||
__C.SUPERNET.FLOPS_MINIMUM = 0 | |||||
__C.SUPERNET.PICK_METHOD = 'meta' # pick teacher method | |||||
__C.SUPERNET.META_STA_EPOCH = 20 # start using meta picking method | |||||
__C.SUPERNET.HOW_TO_PROB = 'pre_prob' # sample method | |||||
__C.SUPERNET.PRE_PROB = (0.05, 0.2, 0.05, 0.5, 0.05, | |||||
0.15) # sample prob in 'pre_prob' |
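# A minimal usage sketch (the YAML path is an assumption; any yacs-style file works):
#
#     from lib.config import cfg   # import path depends on where this module lives
#     cfg.merge_from_file('configs/train.yaml')   # override the defaults above
#     cfg.freeze()                                # make the config read-only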
@@ -0,0 +1,129 @@ | |||||
import os | |||||
import time | |||||
import timm | |||||
import torch | |||||
import torchvision | |||||
from collections import OrderedDict | |||||
from ..utils.util import AverageMeter, accuracy, reduce_tensor | |||||
def train_epoch( | |||||
epoch, model, loader, optimizer, loss_fn, args, | |||||
lr_scheduler=None, saver=None, output_dir='', use_amp=False, | |||||
model_ema=None, logger=None, writer=None, local_rank=0): | |||||
batch_time_m = AverageMeter() | |||||
data_time_m = AverageMeter() | |||||
losses_m = AverageMeter() | |||||
prec1_m = AverageMeter() | |||||
prec5_m = AverageMeter() | |||||
model.train() | |||||
end = time.time() | |||||
last_idx = len(loader) - 1 | |||||
num_updates = epoch * len(loader) | |||||
optimizer.zero_grad() | |||||
for batch_idx, (input, target) in enumerate(loader): | |||||
last_batch = batch_idx == last_idx | |||||
data_time_m.update(time.time() - end) | |||||
input = input.cuda() | |||||
target = target.cuda() | |||||
output = model(input) | |||||
loss = loss_fn(output, target) | |||||
prec1, prec5 = accuracy(output, target, topk=(1, 5)) | |||||
if args.num_gpu > 1: | |||||
reduced_loss = reduce_tensor(loss.data, args.num_gpu) | |||||
prec1 = reduce_tensor(prec1, args.num_gpu) | |||||
prec5 = reduce_tensor(prec5, args.num_gpu) | |||||
else: | |||||
reduced_loss = loss.data | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
torch.cuda.synchronize() | |||||
losses_m.update(reduced_loss.item(), input.size(0)) | |||||
prec1_m.update(prec1.item(), output.size(0)) | |||||
prec5_m.update(prec5.item(), output.size(0)) | |||||
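        # keep an exponential moving average copy of the weights, if enabled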
if model_ema is not None: | |||||
model_ema.update(model) | |||||
num_updates += 1 | |||||
batch_time_m.update(time.time() - end) | |||||
if last_batch or batch_idx % args.log_interval == 0: | |||||
lrl = [param_group['lr'] for param_group in optimizer.param_groups] | |||||
lr = sum(lrl) / len(lrl) | |||||
if local_rank == 0: | |||||
logger.info( | |||||
'Train: {} [{:>4d}/{}] ' | |||||
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' | |||||
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' | |||||
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) ' | |||||
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' | |||||
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' | |||||
                        'LR: {lr:.3e} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( | |||||
epoch, | |||||
batch_idx, | |||||
len(loader), | |||||
loss=losses_m, | |||||
top1=prec1_m, | |||||
top5=prec5_m, | |||||
batch_time=batch_time_m, | |||||
rate=input.size(0) * args.num_gpu / batch_time_m.val, | |||||
rate_avg=input.size(0) * args.num_gpu / batch_time_m.avg, | |||||
lr=lr, | |||||
data_time=data_time_m)) | |||||
# writer.add_scalar( | |||||
# 'Loss/train', prec1_m.avg, epoch * len(loader) + batch_idx) | |||||
# writer.add_scalar( | |||||
# 'Accuracy/train', prec1_m.avg, epoch * len(loader) + batch_idx) | |||||
# writer.add_scalar( | |||||
# 'Learning_Rate', | |||||
# optimizer.param_groups[0]['lr'], epoch * len(loader) + batch_idx) | |||||
if args.save_images and output_dir: | |||||
torchvision.utils.save_image( | |||||
input, os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), | |||||
padding=0, normalize=True) | |||||
if saver is not None and args.recovery_interval and ( | |||||
last_batch or (batch_idx + 1) % args.recovery_interval == 0): | |||||
            # newer timm versions (>= 0.3) changed the save_recovery signature
            if int(timm.__version__.split('.')[1]) >= 3:
saver.save_recovery( | |||||
epoch, | |||||
batch_idx=batch_idx) | |||||
else: | |||||
saver.save_recovery( | |||||
model, | |||||
optimizer, | |||||
args, | |||||
epoch, | |||||
model_ema=model_ema, | |||||
use_amp=use_amp, | |||||
batch_idx=batch_idx) | |||||
if lr_scheduler is not None: | |||||
lr_scheduler.step_update( | |||||
num_updates=num_updates, | |||||
metric=losses_m.avg) | |||||
end = time.time() | |||||
# end for | |||||
if hasattr(optimizer, 'sync_lookahead'): | |||||
optimizer.sync_lookahead() | |||||
return OrderedDict([('loss', losses_m.avg)]) |
@@ -0,0 +1,100 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
import time | |||||
import torch | |||||
import json | |||||
from collections import OrderedDict | |||||
from ..utils.util import AverageMeter, accuracy, reduce_tensor | |||||
def validate(epoch, model, loader, loss_fn, args, log_suffix='', | |||||
logger=None, writer=None, local_rank=0,result_path=None): | |||||
batch_time_m = AverageMeter() | |||||
losses_m = AverageMeter() | |||||
prec1_m = AverageMeter() | |||||
prec5_m = AverageMeter() | |||||
model.eval() | |||||
end = time.time() | |||||
last_idx = len(loader) - 1 | |||||
with torch.no_grad(): | |||||
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            # move the batch to GPU to match the model device
            input, target = input.cuda(), target.cuda()
            output = model(input)
if isinstance(output, (tuple, list)): | |||||
output = output[0] | |||||
# augmentation reduction | |||||
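            # with TTA the batch holds ``tta`` augmented copies of each sample back-to-back:
            # average their logits and keep every ``tta``-th target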
reduce_factor = args.tta | |||||
if reduce_factor > 1: | |||||
output = output.unfold( | |||||
0, | |||||
reduce_factor, | |||||
reduce_factor).mean( | |||||
dim=2) | |||||
target = target[0:target.size(0):reduce_factor] | |||||
loss = loss_fn(output, target) | |||||
prec1, prec5 = accuracy(output, target, topk=(1, 5)) | |||||
if args.num_gpu > 1: | |||||
reduced_loss = reduce_tensor(loss.data, args.num_gpu) | |||||
prec1 = reduce_tensor(prec1, args.num_gpu) | |||||
prec5 = reduce_tensor(prec5, args.num_gpu) | |||||
else: | |||||
reduced_loss = loss.data | |||||
torch.cuda.synchronize() | |||||
losses_m.update(reduced_loss.item(), input.size(0)) | |||||
prec1_m.update(prec1.item(), output.size(0)) | |||||
prec5_m.update(prec5.item(), output.size(0)) | |||||
batch_time_m.update(time.time() - end) | |||||
end = time.time() | |||||
if local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): | |||||
log_name = 'Test' + log_suffix | |||||
logger.info( | |||||
'{0}: [{1:>4d}/{2}] ' | |||||
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' | |||||
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' | |||||
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' | |||||
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( | |||||
log_name, batch_idx, last_idx, | |||||
batch_time=batch_time_m, loss=losses_m, | |||||
top1=prec1_m, top5=prec5_m)) | |||||
# print({'type': 'Accuracy', 'result': {'sequence': epoch, 'category': 'epoch', 'value': prec1_m.val}}) | |||||
if result_path is not None: | |||||
with open(result_path, "a") as ss_file: | |||||
ss_file.write(json.dumps( | |||||
{'type': 'Accuracy', | |||||
'result': {'sequence': epoch, | |||||
'category': 'epoch', | |||||
'value': prec1_m.val}}) + '\n') | |||||
                # writer.add_scalar(
                #     'Loss' + log_suffix + '/valid',
                #     prec1_m.avg,
                #     epoch * len(loader) + batch_idx)
                # writer.add_scalar(
                #     'Accuracy' + log_suffix + '/valid',
                #     prec1_m.avg,
                #     epoch * len(loader) + batch_idx)
metrics = OrderedDict( | |||||
[('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)]) | |||||
return metrics |
@@ -0,0 +1,2 @@ | |||||
from .residual_block import get_Bottleneck, get_BasicBlock | |||||
from .inverted_residual_block import InvertedResidual |
@@ -0,0 +1,113 @@ | |||||
# This file is downloaded from | |||||
# https://github.com/rwightman/pytorch-image-models | |||||
import torch.nn as nn | |||||
from timm.models.layers import create_conv2d | |||||
from timm.models.efficientnet_blocks import make_divisible, resolve_se_args, \ | |||||
SqueezeExcite, drop_path | |||||
class InvertedResidual(nn.Module): | |||||
""" Inverted residual block w/ optional SE and CondConv routing""" | |||||
def __init__( | |||||
self, | |||||
in_chs, | |||||
out_chs, | |||||
dw_kernel_size=3, | |||||
stride=1, | |||||
dilation=1, | |||||
pad_type='', | |||||
act_layer=nn.ReLU, | |||||
noskip=False, | |||||
exp_ratio=1.0, | |||||
exp_kernel_size=1, | |||||
pw_kernel_size=1, | |||||
se_ratio=0., | |||||
se_kwargs=None, | |||||
norm_layer=nn.BatchNorm2d, | |||||
norm_kwargs=None, | |||||
conv_kwargs=None, | |||||
drop_path_rate=0.): | |||||
super(InvertedResidual, self).__init__() | |||||
norm_kwargs = norm_kwargs or {} | |||||
conv_kwargs = conv_kwargs or {} | |||||
mid_chs = make_divisible(in_chs * exp_ratio) | |||||
has_se = se_ratio is not None and se_ratio > 0. | |||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip | |||||
self.drop_path_rate = drop_path_rate | |||||
# Point-wise expansion | |||||
self.conv_pw = create_conv2d( | |||||
in_chs, | |||||
mid_chs, | |||||
exp_kernel_size, | |||||
padding=pad_type, | |||||
**conv_kwargs) | |||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs) | |||||
self.act1 = act_layer(inplace=True) | |||||
# Depth-wise convolution | |||||
self.conv_dw = create_conv2d( | |||||
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, | |||||
padding=pad_type, depthwise=True, **conv_kwargs) | |||||
self.bn2 = norm_layer(mid_chs, **norm_kwargs) | |||||
self.act2 = act_layer(inplace=True) | |||||
# Squeeze-and-excitation | |||||
if has_se: | |||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) | |||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) | |||||
else: | |||||
self.se = None | |||||
# Point-wise linear projection | |||||
self.conv_pwl = create_conv2d( | |||||
mid_chs, | |||||
out_chs, | |||||
pw_kernel_size, | |||||
padding=pad_type, | |||||
**conv_kwargs) | |||||
self.bn3 = norm_layer(out_chs, **norm_kwargs) | |||||
def feature_info(self, location): | |||||
if location == 'expansion': # after SE, input to PWL | |||||
info = dict( | |||||
module='conv_pwl', | |||||
hook_type='forward_pre', | |||||
num_chs=self.conv_pwl.in_channels) | |||||
else: # location == 'bottleneck', block output | |||||
info = dict( | |||||
module='', | |||||
hook_type='', | |||||
num_chs=self.conv_pwl.out_channels) | |||||
return info | |||||
def forward(self, x): | |||||
residual = x | |||||
# Point-wise expansion | |||||
x = self.conv_pw(x) | |||||
x = self.bn1(x) | |||||
x = self.act1(x) | |||||
# Depth-wise convolution | |||||
x = self.conv_dw(x) | |||||
x = self.bn2(x) | |||||
x = self.act2(x) | |||||
# Squeeze-and-excitation | |||||
if self.se is not None: | |||||
x = self.se(x) | |||||
# Point-wise linear projection | |||||
x = self.conv_pwl(x) | |||||
x = self.bn3(x) | |||||
if self.has_residual: | |||||
if self.drop_path_rate > 0.: | |||||
x = drop_path(x, self.drop_path_rate, self.training) | |||||
x += residual | |||||
return x |
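# A quick smoke test (shapes and arguments are illustrative only):
if __name__ == '__main__':
    import torch
    block = InvertedResidual(in_chs=32, out_chs=32, dw_kernel_size=3, exp_ratio=6.0, se_ratio=0.25)
    x = torch.randn(2, 32, 56, 56)
    # stride 1 and in_chs == out_chs, so the residual path is active and the shape is preserved
    assert block(x).shape == x.shape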
@@ -0,0 +1,105 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
def conv3x3(in_planes, out_planes, stride=1): | |||||
"3x3 convolution with padding" | |||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |||||
padding=1, bias=True) | |||||
class BasicBlock(nn.Module): | |||||
expansion = 1 | |||||
def __init__(self, inplanes, planes, stride=1, downsample=None): | |||||
super(BasicBlock, self).__init__() | |||||
self.conv1 = conv3x3(inplanes, planes, stride) | |||||
self.bn1 = nn.BatchNorm2d(planes) | |||||
self.relu = nn.ReLU(inplace=True) | |||||
self.conv2 = conv3x3(planes, planes) | |||||
self.bn2 = nn.BatchNorm2d(planes) | |||||
self.downsample = downsample | |||||
self.stride = stride | |||||
def forward(self, x): | |||||
residual = x | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
if self.downsample is not None: | |||||
residual = self.downsample(x) | |||||
out += residual | |||||
out = self.relu(out) | |||||
return out | |||||
class Bottleneck(nn.Module): | |||||
def __init__(self, inplanes, planes, stride=1, expansion=4): | |||||
super(Bottleneck, self).__init__() | |||||
planes = int(planes / expansion) | |||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True) | |||||
self.bn1 = nn.BatchNorm2d(planes) | |||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, | |||||
padding=1, bias=True) | |||||
self.bn2 = nn.BatchNorm2d(planes) | |||||
self.conv3 = nn.Conv2d( | |||||
planes, | |||||
planes * expansion, | |||||
kernel_size=1, | |||||
bias=True) | |||||
self.bn3 = nn.BatchNorm2d(planes * expansion) | |||||
self.relu = nn.ReLU(inplace=True) | |||||
self.stride = stride | |||||
self.expansion = expansion | |||||
if inplanes != planes * self.expansion: | |||||
self.downsample = nn.Sequential( | |||||
nn.Conv2d(inplanes, planes * self.expansion, | |||||
kernel_size=1, stride=stride, bias=True), | |||||
nn.BatchNorm2d(planes * self.expansion), | |||||
) | |||||
else: | |||||
self.downsample = None | |||||
def forward(self, x): | |||||
residual = x | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
out = self.relu(out) | |||||
out = self.conv3(out) | |||||
out = self.bn3(out) | |||||
if self.downsample is not None: | |||||
residual = self.downsample(x) | |||||
out += residual | |||||
out = self.relu(out) | |||||
return out | |||||
def get_Bottleneck(in_c, out_c, stride): | |||||
return Bottleneck(in_c, out_c, stride=stride) | |||||
def get_BasicBlock(in_c, out_c, stride): | |||||
return BasicBlock(in_c, out_c, stride=stride) |
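# A quick smoke test (sizes are illustrative only):
if __name__ == '__main__':
    x = torch.randn(2, 64, 56, 56)
    block = get_Bottleneck(64, 128, stride=2)   # the downsample path is created automatically
    # stride 2 halves the spatial size; channels go from 64 to 128
    assert block(x).shape == (2, 128, 28, 28)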
@@ -0,0 +1,182 @@ | |||||
from ...utils.util import * | |||||
from collections import OrderedDict | |||||
from timm.models.efficientnet_blocks import * | |||||
class ChildNetBuilder: | |||||
def __init__( | |||||
self, | |||||
channel_multiplier=1.0, | |||||
channel_divisor=8, | |||||
channel_min=None, | |||||
output_stride=32, | |||||
pad_type='', | |||||
act_layer=None, | |||||
se_kwargs=None, | |||||
norm_layer=nn.BatchNorm2d, | |||||
norm_kwargs=None, | |||||
drop_path_rate=0., | |||||
feature_location='', | |||||
verbose=False, | |||||
logger=None): | |||||
self.channel_multiplier = channel_multiplier | |||||
self.channel_divisor = channel_divisor | |||||
self.channel_min = channel_min | |||||
self.output_stride = output_stride | |||||
self.pad_type = pad_type | |||||
self.act_layer = act_layer | |||||
self.se_kwargs = se_kwargs | |||||
self.norm_layer = norm_layer | |||||
self.norm_kwargs = norm_kwargs | |||||
self.drop_path_rate = drop_path_rate | |||||
self.feature_location = feature_location | |||||
assert feature_location in ('pre_pwl', 'post_exp', '') | |||||
self.verbose = verbose | |||||
self.in_chs = None | |||||
self.features = OrderedDict() | |||||
self.logger = logger | |||||
def _round_channels(self, chs): | |||||
return round_channels( | |||||
chs, | |||||
self.channel_multiplier, | |||||
self.channel_divisor, | |||||
self.channel_min) | |||||
def _make_block(self, ba, block_idx, block_count): | |||||
drop_path_rate = self.drop_path_rate * block_idx / block_count | |||||
bt = ba.pop('block_type') | |||||
ba['in_chs'] = self.in_chs | |||||
ba['out_chs'] = self._round_channels(ba['out_chs']) | |||||
if 'fake_in_chs' in ba and ba['fake_in_chs']: | |||||
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) | |||||
ba['norm_layer'] = self.norm_layer | |||||
ba['norm_kwargs'] = self.norm_kwargs | |||||
ba['pad_type'] = self.pad_type | |||||
# block act fn overrides the model default | |||||
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer | |||||
assert ba['act_layer'] is not None | |||||
if bt == 'ir': | |||||
ba['drop_path_rate'] = drop_path_rate | |||||
ba['se_kwargs'] = self.se_kwargs | |||||
if self.verbose: | |||||
self.logger.info( | |||||
' InvertedResidual {}, Args: {}'.format( | |||||
block_idx, str(ba))) | |||||
block = InvertedResidual(**ba) | |||||
elif bt == 'ds' or bt == 'dsa': | |||||
ba['drop_path_rate'] = drop_path_rate | |||||
ba['se_kwargs'] = self.se_kwargs | |||||
if self.verbose: | |||||
self.logger.info( | |||||
' DepthwiseSeparable {}, Args: {}'.format( | |||||
block_idx, str(ba))) | |||||
block = DepthwiseSeparableConv(**ba) | |||||
elif bt == 'cn': | |||||
if self.verbose: | |||||
self.logger.info( | |||||
' ConvBnAct {}, Args: {}'.format( | |||||
block_idx, str(ba))) | |||||
block = ConvBnAct(**ba) | |||||
else: | |||||
            assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block | |||||
return block | |||||
def __call__(self, in_chs, model_block_args): | |||||
""" Build the blocks | |||||
Args: | |||||
        in_chs: Number of input channels passed to the first block
        model_block_args: A list of lists; the outer list defines stages, each
            inner list contains the block-argument dicts for that stage
Return: | |||||
List of block stacks (each stack wrapped in nn.Sequential) | |||||
""" | |||||
if self.verbose: | |||||
self.logger.info( | |||||
'Building model trunk with %d stages...' % | |||||
len(model_block_args)) | |||||
self.in_chs = in_chs | |||||
total_block_count = sum([len(x) for x in model_block_args]) | |||||
total_block_idx = 0 | |||||
current_stride = 2 | |||||
current_dilation = 1 | |||||
feature_idx = 0 | |||||
stages = [] | |||||
# outer list of block_args defines the stacks ('stages' by some | |||||
# conventions) | |||||
for stage_idx, stage_block_args in enumerate(model_block_args): | |||||
last_stack = stage_idx == (len(model_block_args) - 1) | |||||
if self.verbose: | |||||
self.logger.info('Stack: {}'.format(stage_idx)) | |||||
assert isinstance(stage_block_args, list) | |||||
blocks = [] | |||||
# each stack (stage) contains a list of block arguments | |||||
for block_idx, block_args in enumerate(stage_block_args): | |||||
last_block = block_idx == (len(stage_block_args) - 1) | |||||
extract_features = '' # No features extracted | |||||
if self.verbose: | |||||
self.logger.info(' Block: {}'.format(block_idx)) | |||||
# Sort out stride, dilation, and feature extraction details | |||||
assert block_args['stride'] in (1, 2) | |||||
if block_idx >= 1: | |||||
# only the first block in any stack can have a stride > 1 | |||||
block_args['stride'] = 1 | |||||
do_extract = False | |||||
if self.feature_location == 'pre_pwl': | |||||
if last_block: | |||||
next_stage_idx = stage_idx + 1 | |||||
if next_stage_idx >= len(model_block_args): | |||||
do_extract = True | |||||
else: | |||||
do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 | |||||
elif self.feature_location == 'post_exp': | |||||
if block_args['stride'] > 1 or (last_stack and last_block): | |||||
do_extract = True | |||||
if do_extract: | |||||
extract_features = self.feature_location | |||||
next_dilation = current_dilation | |||||
if block_args['stride'] > 1: | |||||
next_output_stride = current_stride * block_args['stride'] | |||||
if next_output_stride > self.output_stride: | |||||
next_dilation = current_dilation * block_args['stride'] | |||||
block_args['stride'] = 1 | |||||
if self.verbose: | |||||
self.logger.info( | |||||
' Converting stride to dilation to maintain output_stride=={}'.format( | |||||
self.output_stride)) | |||||
else: | |||||
current_stride = next_output_stride | |||||
block_args['dilation'] = current_dilation | |||||
if next_dilation != current_dilation: | |||||
current_dilation = next_dilation | |||||
# create the block | |||||
block = self._make_block( | |||||
block_args, total_block_idx, total_block_count) | |||||
blocks.append(block) | |||||
# stash feature module name and channel info for model feature | |||||
# extraction | |||||
if extract_features: | |||||
feature_module = block.feature_module(extract_features) | |||||
if feature_module: | |||||
feature_module = 'blocks.{}.{}.'.format( | |||||
stage_idx, block_idx) + feature_module | |||||
feature_channels = block.feature_channels(extract_features) | |||||
self.features[feature_idx] = dict( | |||||
name=feature_module, | |||||
num_chs=feature_channels | |||||
) | |||||
feature_idx += 1 | |||||
# incr global block idx (across all stacks) | |||||
total_block_idx += 1 | |||||
stages.append(nn.Sequential(*blocks)) | |||||
return stages |
@@ -0,0 +1,214 @@ | |||||
from copy import deepcopy | |||||
from ...utils.builder_util import modify_block_args | |||||
from ..blocks import get_Bottleneck, InvertedResidual | |||||
from timm.models.efficientnet_blocks import * | |||||
from pytorch.mutables import LayerChoice | |||||
class SuperNetBuilder: | |||||
""" Build Trunk Blocks | |||||
""" | |||||
def __init__( | |||||
self, | |||||
choices, | |||||
channel_multiplier=1.0, | |||||
channel_divisor=8, | |||||
channel_min=None, | |||||
output_stride=32, | |||||
pad_type='', | |||||
act_layer=None, | |||||
se_kwargs=None, | |||||
norm_layer=nn.BatchNorm2d, | |||||
norm_kwargs=None, | |||||
drop_path_rate=0., | |||||
feature_location='', | |||||
verbose=False, | |||||
resunit=False, | |||||
dil_conv=False, | |||||
logger=None): | |||||
# dict | |||||
# choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]} | |||||
self.choices = [[x, y] for x in choices['kernel_size'] | |||||
for y in choices['exp_ratio']] | |||||
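        # e.g. kernel_size [3, 5, 7] x exp_ratio [4, 6] -> 6 candidate operations per layer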
self.choices_num = len(self.choices) - 1 | |||||
self.channel_multiplier = channel_multiplier | |||||
self.channel_divisor = channel_divisor | |||||
self.channel_min = channel_min | |||||
self.output_stride = output_stride | |||||
self.pad_type = pad_type | |||||
self.act_layer = act_layer | |||||
self.se_kwargs = se_kwargs | |||||
self.norm_layer = norm_layer | |||||
self.norm_kwargs = norm_kwargs | |||||
self.drop_path_rate = drop_path_rate | |||||
self.feature_location = feature_location | |||||
assert feature_location in ('pre_pwl', 'post_exp', '') | |||||
self.verbose = verbose | |||||
self.resunit = resunit | |||||
self.dil_conv = dil_conv | |||||
self.logger = logger | |||||
# state updated during build, consumed by model | |||||
self.in_chs = None | |||||
def _round_channels(self, chs): | |||||
return round_channels( | |||||
chs, | |||||
self.channel_multiplier, | |||||
self.channel_divisor, | |||||
self.channel_min) | |||||
def _make_block( | |||||
self, | |||||
ba, | |||||
choice_idx, | |||||
block_idx, | |||||
block_count, | |||||
resunit=False, | |||||
dil_conv=False): | |||||
drop_path_rate = self.drop_path_rate * block_idx / block_count | |||||
bt = ba.pop('block_type') | |||||
ba['in_chs'] = self.in_chs | |||||
ba['out_chs'] = self._round_channels(ba['out_chs']) | |||||
if 'fake_in_chs' in ba and ba['fake_in_chs']: | |||||
# FIXME this is a hack to work around mismatch in origin impl input | |||||
# filters | |||||
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) | |||||
ba['norm_layer'] = self.norm_layer | |||||
ba['norm_kwargs'] = self.norm_kwargs | |||||
ba['pad_type'] = self.pad_type | |||||
# block act fn overrides the model default | |||||
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer | |||||
assert ba['act_layer'] is not None | |||||
if bt == 'ir': | |||||
ba['drop_path_rate'] = drop_path_rate | |||||
ba['se_kwargs'] = self.se_kwargs | |||||
if self.verbose: | |||||
self.logger.info( | |||||
' InvertedResidual {}, Args: {}'.format( | |||||
block_idx, str(ba))) | |||||
block = InvertedResidual(**ba) | |||||
elif bt == 'ds' or bt == 'dsa': | |||||
ba['drop_path_rate'] = drop_path_rate | |||||
ba['se_kwargs'] = self.se_kwargs | |||||
if self.verbose: | |||||
self.logger.info( | |||||
' DepthwiseSeparable {}, Args: {}'.format( | |||||
block_idx, str(ba))) | |||||
block = DepthwiseSeparableConv(**ba) | |||||
elif bt == 'cn': | |||||
if self.verbose: | |||||
self.logger.info( | |||||
' ConvBnAct {}, Args: {}'.format( | |||||
block_idx, str(ba))) | |||||
block = ConvBnAct(**ba) | |||||
else: | |||||
            assert False, 'Unknown block type (%s) while building model.' % bt
if choice_idx == self.choice_num - 1: | |||||
self.in_chs = ba['out_chs'] # update in_chs for arg of next block | |||||
return block | |||||
def __call__(self, in_chs, model_block_args): | |||||
""" Build the blocks | |||||
Args: | |||||
            in_chs: Number of input channels passed to the first block
            model_block_args: A list of lists; the outer list defines stages, each
                inner list contains the block-argument dicts for that stage
Return: | |||||
List of block stacks (each stack wrapped in nn.Sequential) | |||||
""" | |||||
if self.verbose: | |||||
self.logger.info('Building model trunk with %d stages...' % len(model_block_args)) | |||||
self.in_chs = in_chs | |||||
total_block_count = sum([len(x) for x in model_block_args]) | |||||
total_block_idx = 0 | |||||
current_stride = 2 | |||||
current_dilation = 1 | |||||
feature_idx = 0 | |||||
stages = [] | |||||
# outer list of block_args defines the stacks ('stages' by some conventions) | |||||
for stage_idx, stage_block_args in enumerate(model_block_args): | |||||
last_stack = stage_idx == (len(model_block_args) - 1) | |||||
if self.verbose: | |||||
self.logger.info('Stack: {}'.format(stage_idx)) | |||||
assert isinstance(stage_block_args, list) | |||||
# blocks = [] | |||||
# each stack (stage) contains a list of block arguments | |||||
for block_idx, block_args in enumerate(stage_block_args): | |||||
last_block = block_idx == (len(stage_block_args) - 1) | |||||
if self.verbose: | |||||
self.logger.info(' Block: {}'.format(block_idx)) | |||||
# Sort out stride, dilation, and feature extraction details | |||||
assert block_args['stride'] in (1, 2) | |||||
if block_idx >= 1: | |||||
# only the first block in any stack can have a stride > 1 | |||||
block_args['stride'] = 1 | |||||
next_dilation = current_dilation | |||||
if block_args['stride'] > 1: | |||||
next_output_stride = current_stride * block_args['stride'] | |||||
if next_output_stride > self.output_stride: | |||||
next_dilation = current_dilation * block_args['stride'] | |||||
block_args['stride'] = 1 | |||||
else: | |||||
current_stride = next_output_stride | |||||
block_args['dilation'] = current_dilation | |||||
if next_dilation != current_dilation: | |||||
current_dilation = next_dilation | |||||
                if stage_idx == 0 or stage_idx == 6:
                    # the first and last stages are fixed: a single candidate, no choice
                    self.choice_num = 1
                else:
                    self.choice_num = len(self.choices)
                    if self.dil_conv:
                        # two extra dilated-conv candidates are appended below
                        self.choice_num += 2
choice_blocks = [] | |||||
block_args_copy = deepcopy(block_args) | |||||
if self.choice_num == 1: | |||||
# create the block | |||||
block = self._make_block(block_args, 0, total_block_idx, total_block_count) | |||||
choice_blocks.append(block) | |||||
else: | |||||
for choice_idx, choice in enumerate(self.choices): | |||||
# create the block | |||||
block_args = deepcopy(block_args_copy) | |||||
block_args = modify_block_args(block_args, choice[0], choice[1]) | |||||
block = self._make_block(block_args, choice_idx, total_block_idx, total_block_count) | |||||
choice_blocks.append(block) | |||||
if self.dil_conv: | |||||
block_args = deepcopy(block_args_copy) | |||||
block_args = modify_block_args(block_args, 3, 0) | |||||
block = self._make_block(block_args, self.choice_num - 2, total_block_idx, total_block_count, | |||||
resunit=self.resunit, dil_conv=self.dil_conv) | |||||
choice_blocks.append(block) | |||||
block_args = deepcopy(block_args_copy) | |||||
block_args = modify_block_args(block_args, 5, 0) | |||||
block = self._make_block(block_args, self.choice_num - 1, total_block_idx, total_block_count, | |||||
resunit=self.resunit, dil_conv=self.dil_conv) | |||||
choice_blocks.append(block) | |||||
if self.resunit: | |||||
block = get_Bottleneck(block.conv_pw.in_channels, | |||||
block.conv_pwl.out_channels, | |||||
block.conv_dw.stride[0]) | |||||
choice_blocks.append(block) | |||||
choice_block = LayerChoice(choice_blocks) | |||||
stages.append(choice_block) | |||||
# create the block | |||||
# block = self._make_block(block_args, total_block_idx, total_block_count) | |||||
total_block_idx += 1 # incr global block idx (across all stacks) | |||||
# stages.append(blocks) | |||||
return stages |
@@ -0,0 +1,145 @@ | |||||
from ...utils.builder_util import * | |||||
from ..builders.build_childnet import * | |||||
from timm.models.layers import SelectAdaptivePool2d | |||||
from timm.models.layers.activations import hard_sigmoid | |||||
class ChildNet(nn.Module): | |||||
def __init__( | |||||
self, | |||||
block_args, | |||||
num_classes=1000, | |||||
in_chans=3, | |||||
stem_size=16, | |||||
num_features=1280, | |||||
head_bias=True, | |||||
channel_multiplier=1.0, | |||||
pad_type='', | |||||
act_layer=nn.ReLU, | |||||
drop_rate=0., | |||||
drop_path_rate=0., | |||||
se_kwargs=None, | |||||
norm_layer=nn.BatchNorm2d, | |||||
norm_kwargs=None, | |||||
global_pool='avg', | |||||
logger=None, | |||||
verbose=False): | |||||
super(ChildNet, self).__init__() | |||||
self.num_classes = num_classes | |||||
self.num_features = num_features | |||||
self.drop_rate = drop_rate | |||||
self._in_chs = in_chans | |||||
self.logger = logger | |||||
# Stem | |||||
stem_size = round_channels(stem_size, channel_multiplier) | |||||
self.conv_stem = create_conv2d( | |||||
self._in_chs, stem_size, 3, stride=2, padding=pad_type) | |||||
self.bn1 = norm_layer(stem_size, **norm_kwargs) | |||||
self.act1 = act_layer(inplace=True) | |||||
self._in_chs = stem_size | |||||
# Middle stages (IR/ER/DS Blocks) | |||||
builder = ChildNetBuilder( | |||||
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs, | |||||
norm_layer, norm_kwargs, drop_path_rate, verbose=verbose) | |||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) | |||||
# self.blocks = builder(self._in_chs, block_args) | |||||
self._in_chs = builder.in_chs | |||||
# Head + Pooling | |||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) | |||||
self.conv_head = create_conv2d( | |||||
self._in_chs, | |||||
self.num_features, | |||||
1, | |||||
padding=pad_type, | |||||
bias=head_bias) | |||||
self.act2 = act_layer(inplace=True) | |||||
# Classifier | |||||
self.classifier = nn.Linear( | |||||
self.num_features * | |||||
self.global_pool.feat_mult(), | |||||
self.num_classes) | |||||
efficientnet_init_weights(self) | |||||
def get_classifier(self): | |||||
return self.classifier | |||||
def reset_classifier(self, num_classes, global_pool='avg'): | |||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) | |||||
self.num_classes = num_classes | |||||
self.classifier = nn.Linear( | |||||
self.num_features * self.global_pool.feat_mult(), | |||||
num_classes) if self.num_classes else None | |||||
def forward_features(self, x): | |||||
# architecture = [[0], [], [], [], [], [0]] | |||||
x = self.conv_stem(x) | |||||
x = self.bn1(x) | |||||
x = self.act1(x) | |||||
x = self.blocks(x) | |||||
x = self.global_pool(x) | |||||
x = self.conv_head(x) | |||||
x = self.act2(x) | |||||
return x | |||||
def forward(self, x): | |||||
x = self.forward_features(x) | |||||
x = x.flatten(1) | |||||
if self.drop_rate > 0.: | |||||
x = F.dropout(x, p=self.drop_rate, training=self.training) | |||||
x = self.classifier(x) | |||||
return x | |||||
def gen_childnet(arch_list, arch_def, **kwargs): | |||||
# arch_list = [[0], [], [], [], [], [0]] | |||||
choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]} | |||||
choices_list = [[x, y] for x in choices['kernel_size'] | |||||
for y in choices['exp_ratio']] | |||||
num_features = 1280 | |||||
# act_layer = HardSwish | |||||
act_layer = Swish | |||||
new_arch = [] | |||||
# change to child arch_def | |||||
for i, (layer_choice, layer_arch) in enumerate(zip(arch_list, arch_def)): | |||||
if len(layer_arch) == 1: | |||||
new_arch.append(layer_arch) | |||||
continue | |||||
else: | |||||
new_layer = [] | |||||
for j, (block_choice, block_arch) in enumerate( | |||||
zip(layer_choice, layer_arch)): | |||||
kernel_size, exp_ratio = choices_list[block_choice] | |||||
elements = block_arch.split('_') | |||||
block_arch = block_arch.replace( | |||||
elements[2], 'k{}'.format(str(kernel_size))) | |||||
block_arch = block_arch.replace( | |||||
elements[4], 'e{}'.format(str(exp_ratio))) | |||||
new_layer.append(block_arch) | |||||
new_arch.append(new_layer) | |||||
model_kwargs = dict( | |||||
block_args=decode_arch_def(new_arch), | |||||
num_features=num_features, | |||||
stem_size=16, | |||||
norm_kwargs=resolve_bn_args(kwargs), | |||||
act_layer=act_layer, | |||||
se_kwargs=dict( | |||||
act_layer=nn.ReLU, | |||||
gate_fn=hard_sigmoid, | |||||
reduce_mid=True, | |||||
divisor=8), | |||||
**kwargs, | |||||
) | |||||
model = ChildNet(**model_kwargs) | |||||
return model |
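# Illustrative usage (a sketch; `arch_def` is assumed to be assembled from
# block strings as in the retrain entry point): arch_list carries the
# per-block indices into choices_list selected during search.
#   model = gen_childnet(
#       arch_list=[[0], [3], [3, 3], [3, 3], [3], [3], [0]],
#       arch_def=arch_def, num_classes=1000, drop_rate=0.0, global_pool='avg')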
@@ -0,0 +1,202 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
from ...utils.builder_util import * | |||||
from ...utils.search_structure_supernet import * | |||||
from ...utils.op_by_layer_dict import flops_op_dict | |||||
from ..builders.build_supernet import * | |||||
from timm.models.layers import SelectAdaptivePool2d | |||||
from timm.models.layers.activations import hard_sigmoid | |||||
class SuperNet(nn.Module): | |||||
def __init__( | |||||
self, | |||||
block_args, | |||||
choices, | |||||
num_classes=1000, | |||||
in_chans=3, | |||||
stem_size=16, | |||||
num_features=1280, | |||||
head_bias=True, | |||||
channel_multiplier=1.0, | |||||
pad_type='', | |||||
act_layer=nn.ReLU, | |||||
drop_rate=0., | |||||
drop_path_rate=0., | |||||
slice=4, | |||||
se_kwargs=None, | |||||
norm_layer=nn.BatchNorm2d, | |||||
logger=None, | |||||
norm_kwargs=None, | |||||
global_pool='avg', | |||||
resunit=False, | |||||
dil_conv=False, | |||||
verbose=False): | |||||
super(SuperNet, self).__init__() | |||||
self.num_classes = num_classes | |||||
self.num_features = num_features | |||||
self.drop_rate = drop_rate | |||||
self._in_chs = in_chans | |||||
self.logger = logger | |||||
# Stem | |||||
stem_size = round_channels(stem_size, channel_multiplier) | |||||
self.conv_stem = create_conv2d( | |||||
self._in_chs, stem_size, 3, stride=2, padding=pad_type) | |||||
self.bn1 = norm_layer(stem_size, **norm_kwargs) | |||||
self.act1 = act_layer(inplace=True) | |||||
self._in_chs = stem_size | |||||
# Middle stages (IR/ER/DS Blocks) | |||||
builder = SuperNetBuilder( | |||||
choices, | |||||
channel_multiplier, | |||||
8, | |||||
None, | |||||
32, | |||||
pad_type, | |||||
act_layer, | |||||
se_kwargs, | |||||
norm_layer, | |||||
norm_kwargs, | |||||
drop_path_rate, | |||||
verbose=verbose, | |||||
resunit=resunit, | |||||
dil_conv=dil_conv, | |||||
logger=self.logger) | |||||
blocks = builder(self._in_chs, block_args) | |||||
self.blocks = nn.Sequential(*blocks) | |||||
self._in_chs = builder.in_chs | |||||
# Head + Pooling | |||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) | |||||
self.conv_head = create_conv2d( | |||||
self._in_chs, | |||||
self.num_features, | |||||
1, | |||||
padding=pad_type, | |||||
bias=head_bias) | |||||
self.act2 = act_layer(inplace=True) | |||||
# Classifier | |||||
self.classifier = nn.Linear( | |||||
self.num_features * | |||||
self.global_pool.feat_mult(), | |||||
self.num_classes) | |||||
self.meta_layer = nn.Linear(self.num_classes * slice, 1) | |||||
efficientnet_init_weights(self) | |||||
def get_classifier(self): | |||||
return self.classifier | |||||
def reset_classifier(self, num_classes, global_pool='avg'): | |||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) | |||||
self.num_classes = num_classes | |||||
self.classifier = nn.Linear( | |||||
self.num_features * self.global_pool.feat_mult(), | |||||
num_classes) if self.num_classes else None | |||||
def forward_features(self, x): | |||||
x = self.conv_stem(x) | |||||
x = self.bn1(x) | |||||
x = self.act1(x) | |||||
x = self.blocks(x) | |||||
x = self.global_pool(x) | |||||
x = self.conv_head(x) | |||||
x = self.act2(x) | |||||
return x | |||||
def forward(self, x): | |||||
x = self.forward_features(x) | |||||
x = x.flatten(1) | |||||
if self.drop_rate > 0.: | |||||
x = F.dropout(x, p=self.drop_rate, training=self.training) | |||||
return self.classifier(x) | |||||
def forward_meta(self, features): | |||||
return self.meta_layer(features.view(1, -1)) | |||||
def rand_parameters(self, architecture, meta=False): | |||||
for name, param in self.named_parameters(recurse=True): | |||||
if 'meta' in name and meta: | |||||
yield param | |||||
elif 'blocks' not in name and 'meta' not in name and (not meta): | |||||
yield param | |||||
if not meta: | |||||
for layer, layer_arch in zip(self.blocks, architecture): | |||||
for blocks, arch in zip(layer, layer_arch): | |||||
if arch == -1: | |||||
continue | |||||
for name, param in blocks[arch].named_parameters( | |||||
recurse=True): | |||||
yield param | |||||
class Classifier(nn.Module): | |||||
def __init__(self, num_classes=1000): | |||||
super(Classifier, self).__init__() | |||||
self.classifier = nn.Linear(num_classes, num_classes) | |||||
def forward(self, x): | |||||
return self.classifier(x) | |||||
def gen_supernet(flops_minimum=0, flops_maximum=600, **kwargs): | |||||
choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]} | |||||
num_features = 1280 | |||||
# act_layer = HardSwish | |||||
act_layer = Swish | |||||
arch_def = [ | |||||
# stage 0, 112x112 in | |||||
['ds_r1_k3_s1_e1_c16_se0.25'], | |||||
# stage 1, 112x112 in | |||||
['ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', | |||||
'ir_r1_k3_s1_e4_c24_se0.25'], | |||||
# stage 2, 56x56 in | |||||
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s1_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25', | |||||
'ir_r1_k5_s2_e4_c40_se0.25'], | |||||
# stage 3, 28x28 in | |||||
['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', | |||||
'ir_r2_k3_s1_e4_c80_se0.25'], | |||||
# stage 4, 14x14in | |||||
['ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', | |||||
'ir_r1_k3_s1_e6_c96_se0.25'], | |||||
# stage 5, 14x14in | |||||
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25', | |||||
'ir_r1_k5_s2_e6_c192_se0.25'], | |||||
# stage 6, 7x7 in | |||||
['cn_r1_k1_s1_c320_se0.25'], | |||||
] | |||||
sta_num, arch_def, resolution = search_for_layer( | |||||
flops_op_dict, arch_def, flops_minimum, flops_maximum) | |||||
if sta_num is None or arch_def is None or resolution is None: | |||||
raise ValueError('Invalid FLOPs Settings') | |||||
model_kwargs = dict( | |||||
block_args=decode_arch_def(arch_def), | |||||
choices=choices, | |||||
num_features=num_features, | |||||
stem_size=16, | |||||
norm_kwargs=resolve_bn_args(kwargs), | |||||
act_layer=act_layer, | |||||
se_kwargs=dict( | |||||
act_layer=nn.ReLU, | |||||
gate_fn=hard_sigmoid, | |||||
reduce_mid=True, | |||||
divisor=8), | |||||
**kwargs, | |||||
) | |||||
model = SuperNet(**model_kwargs) | |||||
return model, sta_num, resolution, arch_def |
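# Illustrative usage (a sketch under the defaults above): build a supernet
# whose candidate paths stay within a 0-600 MFLOPs budget.
#   model, sta_num, resolution, arch_def = gen_supernet(
#       flops_minimum=0, flops_maximum=600, num_classes=1000)
#   # sta_num: searched block count per stage; resolution: input size in pixels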
@@ -0,0 +1,270 @@ | |||||
import re | |||||
import math | |||||
import torch.nn as nn | |||||
from copy import deepcopy | |||||
from timm.utils import * | |||||
from timm.models.layers.activations import Swish | |||||
from timm.models.layers import CondConv2d, get_condconv_initializer | |||||
def parse_ksize(ss): | |||||
if ss.isdigit(): | |||||
return int(ss) | |||||
else: | |||||
return [int(k) for k in ss.split('.')] | |||||
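# Examples (illustrative): a bare digit is a single kernel size, while a
# dot-separated string denotes a list of kernel sizes (mixed kernels).
#   parse_ksize('3')      -> 3
#   parse_ksize('3.5.7')  -> [3, 5, 7]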
def decode_arch_def( | |||||
arch_def, | |||||
depth_multiplier=1.0, | |||||
depth_trunc='ceil', | |||||
experts_multiplier=1): | |||||
arch_args = [] | |||||
for stack_idx, block_strings in enumerate(arch_def): | |||||
assert isinstance(block_strings, list) | |||||
stack_args = [] | |||||
repeats = [] | |||||
for block_str in block_strings: | |||||
assert isinstance(block_str, str) | |||||
ba, rep = decode_block_str(block_str) | |||||
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: | |||||
ba['num_experts'] *= experts_multiplier | |||||
stack_args.append(ba) | |||||
repeats.append(rep) | |||||
arch_args.append( | |||||
scale_stage_depth( | |||||
stack_args, | |||||
repeats, | |||||
depth_multiplier, | |||||
depth_trunc)) | |||||
return arch_args | |||||
def modify_block_args(block_args, kernel_size, exp_ratio): | |||||
block_type = block_args['block_type'] | |||||
if block_type == 'cn': | |||||
block_args['kernel_size'] = kernel_size | |||||
elif block_type == 'er': | |||||
block_args['exp_kernel_size'] = kernel_size | |||||
else: | |||||
block_args['dw_kernel_size'] = kernel_size | |||||
if block_type == 'ir' or block_type == 'er': | |||||
block_args['exp_ratio'] = exp_ratio | |||||
return block_args | |||||
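# Example (illustrative): retarget an inverted-residual block definition to a
# 5x5 depthwise kernel and expansion ratio 6.
#   ba, _ = decode_block_str('ir_r1_k3_s1_e4_c40_se0.25')
#   ba = modify_block_args(ba, kernel_size=5, exp_ratio=6)
#   # ba['dw_kernel_size'] == 5 and ba['exp_ratio'] == 6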
def decode_block_str(block_str): | |||||
""" Decode block definition string | |||||
Gets a list of block arg (dicts) through a string notation of arguments. | |||||
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip | |||||
All args can exist in any order with the exception of the leading string which | |||||
is assumed to indicate the block type. | |||||
    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks, | |||||
k - kernel size, | |||||
s - strides (1-9), | |||||
e - expansion ratio, | |||||
c - output channels, | |||||
se - squeeze/excitation ratio | |||||
    n - activation fn ('re', 'r6', or 'sw'; unrecognized values fall back to the model default)
Args: | |||||
block_str: a string representation of block arguments. | |||||
Returns: | |||||
A list of block args (dicts) | |||||
Raises: | |||||
        ValueError: if the string definition is not properly specified (TODO)
""" | |||||
assert isinstance(block_str, str) | |||||
ops = block_str.split('_') | |||||
block_type = ops[0] # take the block type off the front | |||||
ops = ops[1:] | |||||
options = {} | |||||
noskip = False | |||||
for op in ops: | |||||
# string options being checked on individual basis, combine if they | |||||
# grow | |||||
if op == 'noskip': | |||||
noskip = True | |||||
elif op.startswith('n'): | |||||
# activation fn | |||||
key = op[0] | |||||
v = op[1:] | |||||
if v == 're': | |||||
value = nn.ReLU | |||||
elif v == 'r6': | |||||
value = nn.ReLU6 | |||||
elif v == 'sw': | |||||
value = Swish | |||||
else: | |||||
continue | |||||
options[key] = value | |||||
else: | |||||
# all numeric options | |||||
splits = re.split(r'(\d.*)', op) | |||||
if len(splits) >= 2: | |||||
key, value = splits[:2] | |||||
options[key] = value | |||||
# if act_layer is None, the model default (passed to model init) will be | |||||
# used | |||||
act_layer = options['n'] if 'n' in options else None | |||||
exp_kernel_size = parse_ksize(options['a']) if 'a' in options else 1 | |||||
pw_kernel_size = parse_ksize(options['p']) if 'p' in options else 1 | |||||
# FIXME hack to deal with in_chs issue in TPU def | |||||
fake_in_chs = int(options['fc']) if 'fc' in options else 0 | |||||
num_repeat = int(options['r']) | |||||
# each type of block has different valid arguments, fill accordingly | |||||
if block_type == 'ir': | |||||
block_args = dict( | |||||
block_type=block_type, | |||||
dw_kernel_size=parse_ksize(options['k']), | |||||
exp_kernel_size=exp_kernel_size, | |||||
pw_kernel_size=pw_kernel_size, | |||||
out_chs=int(options['c']), | |||||
exp_ratio=float(options['e']), | |||||
se_ratio=float(options['se']) if 'se' in options else None, | |||||
stride=int(options['s']), | |||||
act_layer=act_layer, | |||||
noskip=noskip, | |||||
) | |||||
if 'cc' in options: | |||||
block_args['num_experts'] = int(options['cc']) | |||||
elif block_type == 'ds' or block_type == 'dsa': | |||||
block_args = dict( | |||||
block_type=block_type, | |||||
dw_kernel_size=parse_ksize(options['k']), | |||||
pw_kernel_size=pw_kernel_size, | |||||
out_chs=int(options['c']), | |||||
se_ratio=float(options['se']) if 'se' in options else None, | |||||
stride=int(options['s']), | |||||
act_layer=act_layer, | |||||
pw_act=block_type == 'dsa', | |||||
noskip=block_type == 'dsa' or noskip, | |||||
) | |||||
elif block_type == 'cn': | |||||
block_args = dict( | |||||
block_type=block_type, | |||||
kernel_size=int(options['k']), | |||||
out_chs=int(options['c']), | |||||
stride=int(options['s']), | |||||
act_layer=act_layer, | |||||
) | |||||
else: | |||||
assert False, 'Unknown block type (%s)' % block_type | |||||
return block_args, num_repeat | |||||
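# Example (illustrative), following the notation documented above:
#   decode_block_str('ir_r2_k3_s2_e4_c40_se0.25')
#   -> ({'block_type': 'ir', 'dw_kernel_size': 3, 'exp_kernel_size': 1,
#        'pw_kernel_size': 1, 'out_chs': 40, 'exp_ratio': 4.0, 'se_ratio': 0.25,
#        'stride': 2, 'act_layer': None, 'noskip': False}, 2)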
def scale_stage_depth( | |||||
stack_args, | |||||
repeats, | |||||
depth_multiplier=1.0, | |||||
depth_trunc='ceil'): | |||||
""" Per-stage depth scaling | |||||
Scales the block repeats in each stage. This depth scaling impl maintains | |||||
compatibility with the EfficientNet scaling method, while allowing sensible | |||||
scaling for other models that may have multiple block arg definitions in each stage. | |||||
""" | |||||
# We scale the total repeat count for each stage, there may be multiple | |||||
# block arg defs per stage so we need to sum. | |||||
num_repeat = sum(repeats) | |||||
if depth_trunc == 'round': | |||||
# Truncating to int by rounding allows stages with few repeats to remain | |||||
# proportionally smaller for longer. This is a good choice when stage definitions | |||||
# include single repeat stages that we'd prefer to keep that way as | |||||
# long as possible | |||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) | |||||
else: | |||||
# The default for EfficientNet truncates repeats to int via 'ceil'. | |||||
# Any multiplier > 1.0 will result in an increased depth for every | |||||
# stage. | |||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) | |||||
# Proportionally distribute repeat count scaling to each block definition in the stage. | |||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled. | |||||
# The first block makes less sense to repeat in most of the arch | |||||
# definitions. | |||||
repeats_scaled = [] | |||||
for r in repeats[::-1]: | |||||
rs = max(1, round((r / num_repeat * num_repeat_scaled))) | |||||
repeats_scaled.append(rs) | |||||
num_repeat -= r | |||||
num_repeat_scaled -= rs | |||||
repeats_scaled = repeats_scaled[::-1] | |||||
# Apply the calculated scaling to each block arg in the stage | |||||
sa_scaled = [] | |||||
for ba, rep in zip(stack_args, repeats_scaled): | |||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) | |||||
return sa_scaled | |||||
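# Worked example (illustrative): repeats=[1, 2, 3] with depth_multiplier=1.2
# gives num_repeat=6 and num_repeat_scaled=ceil(7.2)=8; the reverse allocation
# then yields per-definition repeats [1, 3, 4], so the first block arg is the
# least likely to be scaled up.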
def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None): | |||||
""" Weight initialization as per Tensorflow official implementations. | |||||
Args: | |||||
m (nn.Module): module to init | |||||
n (str): module name | |||||
fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs | |||||
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: | |||||
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py | |||||
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py | |||||
""" | |||||
if isinstance(m, CondConv2d): | |||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||||
if fix_group_fanout: | |||||
fan_out //= m.groups | |||||
init_weight_fn = get_condconv_initializer(lambda w: w.data.normal_( | |||||
0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) | |||||
init_weight_fn(m.weight) | |||||
if m.bias is not None: | |||||
m.bias.data.zero_() | |||||
elif isinstance(m, nn.Conv2d): | |||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||||
if fix_group_fanout: | |||||
fan_out //= m.groups | |||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) | |||||
if m.bias is not None: | |||||
m.bias.data.zero_() | |||||
    elif isinstance(m, nn.BatchNorm2d):
        if last_bn is not None and n in last_bn:
            # zero-gamma: start the last BN of each block at zero when requested
            m.weight.data.zero_()
            m.bias.data.zero_()
        else:
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()
elif isinstance(m, nn.Linear): | |||||
fan_out = m.weight.size(0) # fan-out | |||||
fan_in = 0 | |||||
if 'routing_fn' in n: | |||||
fan_in = m.weight.size(1) | |||||
init_range = 1.0 / math.sqrt(fan_in + fan_out) | |||||
m.weight.data.uniform_(-init_range, init_range) | |||||
m.bias.data.zero_() | |||||
def efficientnet_init_weights( | |||||
model: nn.Module, | |||||
init_fn=None, | |||||
zero_gamma=False): | |||||
last_bn = [] | |||||
if zero_gamma: | |||||
prev_n = '' | |||||
for n, m in model.named_modules(): | |||||
if isinstance(m, nn.BatchNorm2d): | |||||
if ''.join(prev_n.split('.')[:-1]) != ''.join(n.split('.')[:-1]): | |||||
last_bn.append(prev_n) | |||||
prev_n = n | |||||
last_bn.append(prev_n) | |||||
init_fn = init_fn or init_weight_goog | |||||
for n, m in model.named_modules(): | |||||
        init_fn(m, n, last_bn=last_bn)
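# Illustrative usage (a sketch): apply the Google-style init, optionally
# zero-initializing the gamma of each block's last BatchNorm.
#   efficientnet_init_weights(model)
#   efficientnet_init_weights(model, zero_gamma=True)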
@@ -0,0 +1,79 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
import torch | |||||
from ptflops import get_model_complexity_info | |||||
class FlopsEst(object): | |||||
def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'): | |||||
self.block_num = len(model.blocks) | |||||
self.choice_num = len(model.blocks[0]) | |||||
self.flops_dict = {} | |||||
self.params_dict = {} | |||||
if device == 'cpu': | |||||
model = model.cpu() | |||||
else: | |||||
model = model.cuda() | |||||
self.params_fixed = 0 | |||||
self.flops_fixed = 0 | |||||
input = torch.randn(input_shape) | |||||
flops, params = get_model_complexity_info( | |||||
model.conv_stem, (3, 224, 224), as_strings=False, print_per_layer_stat=False) | |||||
self.params_fixed += params / 1e6 | |||||
self.flops_fixed += flops / 1e6 | |||||
input = model.conv_stem(input) | |||||
for block_id, block in enumerate(model.blocks): | |||||
self.flops_dict[block_id] = {} | |||||
self.params_dict[block_id] = {} | |||||
for module_id, module in enumerate(block): | |||||
flops, params = get_model_complexity_info(module, tuple( | |||||
input.shape[1:]), as_strings=False, print_per_layer_stat=False) | |||||
# Flops(M) | |||||
self.flops_dict[block_id][module_id] = flops / 1e6 | |||||
# Params(M) | |||||
self.params_dict[block_id][module_id] = params / 1e6 | |||||
input = module(input) | |||||
        # global pooling
flops, params = get_model_complexity_info(model.global_pool, tuple( | |||||
input.shape[1:]), as_strings=False, print_per_layer_stat=False) | |||||
self.params_fixed += params / 1e6 | |||||
self.flops_fixed += flops / 1e6 | |||||
input = model.global_pool(input) | |||||
        # conv_head (last 1x1 conv)
flops, params = get_model_complexity_info(model.conv_head, tuple( | |||||
input.shape[1:]), as_strings=False, print_per_layer_stat=False) | |||||
self.params_fixed += params / 1e6 | |||||
self.flops_fixed += flops / 1e6 | |||||
# return params (M) | |||||
def get_params(self, arch): | |||||
params = 0 | |||||
for block_id, block in enumerate(arch): | |||||
if block == -1: | |||||
continue | |||||
params += self.params_dict[block_id][block] | |||||
return params + self.params_fixed | |||||
# return flops (M) | |||||
def get_flops(self, arch): | |||||
flops = 0 | |||||
for block_id, block in enumerate(arch): | |||||
            if block == 'LayerChoice1' or block == 'LayerChoice23':
continue | |||||
for idx, choice in enumerate(arch[block]): | |||||
flops += self.flops_dict[block_id][idx] * (1 if choice else 0) | |||||
return flops + self.flops_fixed |
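# Illustrative usage (a sketch; assumes a supernet built by gen_supernet):
#   est = FlopsEst(model)
#   params_m = est.get_params([0, 3, 2, 1, 0])  # per-block choice indices; -1 skips
#   # get_flops() expects the per-layer one-hot choice structure used in search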
@@ -0,0 +1,42 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
# This dictionary is generated from calculating each operation of each layer to quickly search for layers. | |||||
# flops_op_dict[which_stage][which_operation] = | |||||
# (flops_of_operation_with_stride1, flops_of_operation_with_stride2) | |||||
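# Example lookup: flops_op_dict[2][3] == (13.045488, 18.02248), i.e. operation 3
# in stage 2 costs ~13.05 MFLOPs at stride 1 and ~18.02 MFLOPs at stride 2.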
flops_op_dict = {} | |||||
for i in range(5): | |||||
flops_op_dict[i] = {} | |||||
flops_op_dict[0][0] = (21.828704, 18.820752) | |||||
flops_op_dict[0][1] = (32.669328, 28.16048) | |||||
flops_op_dict[0][2] = (25.039968, 23.637648) | |||||
flops_op_dict[0][3] = (37.486224, 35.385824) | |||||
flops_op_dict[0][4] = (29.856864, 30.862992) | |||||
flops_op_dict[0][5] = (44.711568, 46.22384) | |||||
flops_op_dict[1][0] = (11.808656, 11.86712) | |||||
flops_op_dict[1][1] = (17.68624, 17.780848) | |||||
flops_op_dict[1][2] = (13.01288, 13.87416) | |||||
flops_op_dict[1][3] = (19.492576, 20.791408) | |||||
flops_op_dict[1][4] = (14.819216, 16.88472) | |||||
flops_op_dict[1][5] = (22.20208, 25.307248) | |||||
flops_op_dict[2][0] = (8.198, 10.99632) | |||||
flops_op_dict[2][1] = (12.292848, 16.5172) | |||||
flops_op_dict[2][2] = (8.69976, 11.99984) | |||||
flops_op_dict[2][3] = (13.045488, 18.02248) | |||||
flops_op_dict[2][4] = (9.4524, 13.50512) | |||||
flops_op_dict[2][5] = (14.174448, 20.2804) | |||||
flops_op_dict[3][0] = (12.006112, 15.61632) | |||||
flops_op_dict[3][1] = (18.028752, 23.46096) | |||||
flops_op_dict[3][2] = (13.009632, 16.820544) | |||||
flops_op_dict[3][3] = (19.534032, 25.267296) | |||||
flops_op_dict[3][4] = (14.514912, 18.62688) | |||||
flops_op_dict[3][5] = (21.791952, 27.9768) | |||||
flops_op_dict[4][0] = (11.307456, 15.292416) | |||||
flops_op_dict[4][1] = (17.007072, 23.1504) | |||||
flops_op_dict[4][2] = (11.608512, 15.894528) | |||||
flops_op_dict[4][3] = (17.458656, 24.053568) | |||||
flops_op_dict[4][4] = (12.060096, 16.797696) | |||||
flops_op_dict[4][5] = (18.136032, 25.40832) |
@@ -0,0 +1,47 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
def search_for_layer(flops_op_dict, arch_def, flops_minimum, flops_maximum): | |||||
sta_num = [1, 1, 1, 1, 1] | |||||
    # order: the sequence in which stages receive extra blocks when growing
    # depth; limits: the per-visit cap on blocks for each entry in `order`
    order = [2, 3, 4, 1, 0, 2, 3, 4, 1, 0]
    limits = [3, 3, 3, 2, 2, 4, 4, 4, 4, 4]
size_factor = 224 // 32 | |||||
base_min_flops = sum([flops_op_dict[i][0][0] for i in range(5)]) | |||||
base_max_flops = sum([flops_op_dict[i][5][0] for i in range(5)]) | |||||
if base_min_flops > flops_maximum: | |||||
while base_min_flops > flops_maximum and size_factor >= 2: | |||||
size_factor = size_factor - 1 | |||||
flops_minimum = flops_minimum * (7. / size_factor) | |||||
flops_maximum = flops_maximum * (7. / size_factor) | |||||
if size_factor < 2: | |||||
return None, None, None | |||||
elif base_max_flops < flops_minimum: | |||||
cur_ptr = 0 | |||||
while base_max_flops < flops_minimum and cur_ptr <= 9: | |||||
if sta_num[order[cur_ptr]] >= limits[cur_ptr]: | |||||
cur_ptr += 1 | |||||
continue | |||||
base_max_flops = base_max_flops + flops_op_dict[order[cur_ptr]][5][1] | |||||
sta_num[order[cur_ptr]] += 1 | |||||
if cur_ptr > 7 and base_max_flops < flops_minimum: | |||||
return None, None, None | |||||
cur_ptr = 0 | |||||
while cur_ptr <= 9: | |||||
if sta_num[order[cur_ptr]] >= limits[cur_ptr]: | |||||
cur_ptr += 1 | |||||
continue | |||||
base_max_flops = base_max_flops + flops_op_dict[order[cur_ptr]][5][1] | |||||
if base_max_flops <= flops_maximum: | |||||
sta_num[order[cur_ptr]] += 1 | |||||
else: | |||||
break | |||||
arch_def = [item[:i] for i, item in zip([1] + sta_num + [1], arch_def)] | |||||
# print(arch_def) | |||||
return sta_num, arch_def, size_factor * 32 |
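# Illustrative usage (mirrors gen_supernet): choose per-stage block counts and
# an input resolution that keep the supernet within the FLOPs budget.
#   sta_num, arch_def, resolution = search_for_layer(
#       flops_op_dict, arch_def, flops_minimum=0, flops_maximum=600)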
@@ -0,0 +1,178 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
import sys | |||||
import torch | |||||
import logging | |||||
import argparse | |||||
import torch | |||||
import torch.nn as nn | |||||
from copy import deepcopy | |||||
from torch import optim as optim | |||||
from thop import profile, clever_format | |||||
from timm.utils import * | |||||
from ..config import cfg | |||||
def get_path_acc(model, path, val_loader, args, val_iters=50): | |||||
prec1_m = AverageMeter() | |||||
prec5_m = AverageMeter() | |||||
with torch.no_grad(): | |||||
for batch_idx, (input, target) in enumerate(val_loader): | |||||
if batch_idx >= val_iters: | |||||
break | |||||
if not args.prefetcher: | |||||
input = input.cuda() | |||||
target = target.cuda() | |||||
output = model(input, path) | |||||
if isinstance(output, (tuple, list)): | |||||
output = output[0] | |||||
# augmentation reduction | |||||
reduce_factor = args.tta | |||||
if reduce_factor > 1: | |||||
output = output.unfold( | |||||
0, | |||||
reduce_factor, | |||||
reduce_factor).mean( | |||||
dim=2) | |||||
target = target[0:target.size(0):reduce_factor] | |||||
prec1, prec5 = accuracy(output, target, topk=(1, 5)) | |||||
torch.cuda.synchronize() | |||||
prec1_m.update(prec1.item(), output.size(0)) | |||||
prec5_m.update(prec5.item(), output.size(0)) | |||||
return (prec1_m.avg, prec5_m.avg) | |||||
def get_logger(file_path): | |||||
""" Make python logger """ | |||||
log_format = '%(asctime)s | %(message)s' | |||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO, | |||||
format=log_format, datefmt='%m/%d %I:%M:%S %p') | |||||
logger = logging.getLogger() | |||||
logger.setLevel(logging.INFO) | |||||
formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p') | |||||
file_handler = logging.FileHandler(file_path) | |||||
file_handler.setFormatter(formatter) | |||||
logger.addHandler(file_handler) | |||||
return logger | |||||
def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()): | |||||
decay = [] | |||||
no_decay = [] | |||||
meta_layer_no_decay = [] | |||||
meta_layer_decay = [] | |||||
for name, param in model.named_parameters(): | |||||
if not param.requires_grad: | |||||
continue # frozen weights | |||||
if len(param.shape) == 1 or name.endswith( | |||||
".bias") or name in skip_list: | |||||
if 'meta_layer' in name: | |||||
meta_layer_no_decay.append(param) | |||||
else: | |||||
no_decay.append(param) | |||||
else: | |||||
if 'meta_layer' in name: | |||||
meta_layer_decay.append(param) | |||||
else: | |||||
decay.append(param) | |||||
return [ | |||||
{'params': no_decay, 'weight_decay': 0., 'lr': args.lr}, | |||||
{'params': decay, 'weight_decay': weight_decay, 'lr': args.lr}, | |||||
{'params': meta_layer_no_decay, 'weight_decay': 0., 'lr': args.meta_lr}, | |||||
{'params': meta_layer_decay, 'weight_decay': 0, 'lr': args.meta_lr}, | |||||
] | |||||
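# Illustrative usage (a sketch; `args` must provide lr and meta_lr): the four
# groups exempt biases/BN from weight decay and give meta layers their own lr.
#   param_groups = add_weight_decay_supernet(model, args, weight_decay=1e-4)
#   optimizer = optim.SGD(param_groups, momentum=0.9)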
def create_optimizer_supernet(args, model, has_apex=False, filter_bias_and_bn=True): | |||||
weight_decay = args.weight_decay | |||||
    if args.opt in ('adamw', 'radam'):
weight_decay /= args.lr | |||||
if weight_decay and filter_bias_and_bn: | |||||
parameters = add_weight_decay_supernet(model, args, weight_decay) | |||||
weight_decay = 0. | |||||
else: | |||||
parameters = model.parameters() | |||||
if 'fused' == args.opt: | |||||
assert has_apex and torch.cuda.is_available( | |||||
), 'APEX and CUDA required for fused optimizers' | |||||
if args.opt == 'sgd' or args.opt == 'nesterov': | |||||
optimizer = optim.SGD( | |||||
parameters, | |||||
momentum=args.momentum, | |||||
weight_decay=weight_decay, | |||||
nesterov=True) | |||||
elif args.opt == 'momentum': | |||||
optimizer = optim.SGD( | |||||
parameters, | |||||
momentum=args.momentum, | |||||
weight_decay=weight_decay, | |||||
nesterov=False) | |||||
elif args.opt == 'adam': | |||||
optimizer = optim.Adam( | |||||
parameters, weight_decay=weight_decay, eps=args.opt_eps) | |||||
else: | |||||
assert False and "Invalid optimizer" | |||||
raise ValueError | |||||
return optimizer | |||||
def convert_lowercase(cfg): | |||||
keys = cfg.keys() | |||||
lowercase_keys = [key.lower() for key in keys] | |||||
values = [cfg.get(key) for key in keys] | |||||
for lowercase_key, value in zip(lowercase_keys, values): | |||||
cfg.setdefault(lowercase_key, value) | |||||
return cfg | |||||
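# Example (illustrative): after convert_lowercase(cfg), values stored under
# 'MODEL' can also be read via cfg.get('model'), so lookups are case-tolerant.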
def parse_config_args(exp_name):
    # parse command-line args and merge the YAML config; used by the
    # retrain/test entry points that import it from this module
    parser = argparse.ArgumentParser(description=exp_name)
    parser.add_argument(
        '--cfg',
        type=str,
        default='../experiments/workspace/retrain/retrain.yaml',
        help='configuration of cream')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='local_rank')
    args = parser.parse_args()
    cfg.merge_from_file(args.cfg)
    converted_cfg = convert_lowercase(cfg)
    return args, converted_cfg
def get_model_flops_params(model, input_size=(1, 3, 224, 224)): | |||||
input = torch.randn(input_size) | |||||
macs, params = profile(deepcopy(model), inputs=(input,), verbose=False) | |||||
macs, params = clever_format([macs, params], "%.3f") | |||||
return macs, params | |||||
def cross_entropy_loss_with_soft_target(pred, soft_target): | |||||
    logsoftmax = nn.LogSoftmax(dim=1)
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) | |||||
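# Example (illustrative): distillation-style training against a teacher model.
#   soft_target = torch.softmax(teacher_logits, dim=1)
#   loss = cross_entropy_loss_with_soft_target(student_logits, soft_target)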
def create_supernet_scheduler(optimizer, epochs, num_gpu, batch_size, lr): | |||||
    # total optimizer steps: 1,280,000 ~ number of ImageNet train images
    ITERS = epochs * (1280000 / (num_gpu * batch_size))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: ( | |||||
lr - step / ITERS) if step <= ITERS else 0, last_epoch=-1) | |||||
return lr_scheduler, epochs |
@@ -0,0 +1,6 @@ | |||||
## Pretrained models | |||||
The official 14M/43M/114M/287M/481M/604M pretrained models are available on
[Google Drive](https://drive.google.com/drive/folders/1CQjyBryZ4F20Rutj7coF8HWFcedApUn2) and
[Baidu Disk (password: wqw6)](https://pan.baidu.com/s/1TqQNm2s14oEdyNPimw3T9g).
@@ -0,0 +1,462 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
import sys | |||||
sys.path.append('../..') | |||||
import os | |||||
import json | |||||
import time | |||||
import timm | |||||
import torch | |||||
import numpy as np | |||||
import torch.nn as nn | |||||
from argparse import ArgumentParser | |||||
# from torch.utils.tensorboard import SummaryWriter | |||||
# import timm packages | |||||
from timm.optim import create_optimizer | |||||
from timm.models import resume_checkpoint | |||||
from timm.scheduler import create_scheduler | |||||
from timm.data import Dataset, create_loader | |||||
from timm.utils import CheckpointSaver, ModelEma, update_summary | |||||
from timm.loss import LabelSmoothingCrossEntropy | |||||
# import apex as distributed package | |||||
try: | |||||
from apex import amp | |||||
from apex.parallel import DistributedDataParallel as DDP | |||||
from apex.parallel import convert_syncbn_model | |||||
HAS_APEX = True | |||||
except ImportError as e: | |||||
print(e) | |||||
from torch.nn.parallel import DistributedDataParallel as DDP | |||||
HAS_APEX = False | |||||
# import models and training functions | |||||
from pytorch.utils import mkdirs, save_best_checkpoint, str2bool | |||||
from lib.core.test import validate | |||||
from lib.core.retrain import train_epoch | |||||
from lib.models.structures.childnet import gen_childnet | |||||
from lib.utils.util import get_logger, get_model_flops_params | |||||
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |||||
def parse_args(): | |||||
"""See lib.utils.config""" | |||||
parser = ArgumentParser() | |||||
# path | |||||
parser.add_argument("--best_checkpoint_dir", type=str, default='./output/best_checkpoint/') | |||||
parser.add_argument("--checkpoint_dir", type=str, default='./output/checkpoints/') | |||||
parser.add_argument("--data_dir", type=str, default='./data') | |||||
parser.add_argument("--experiment_dir", type=str, default='./') | |||||
parser.add_argument("--model_name", type=str, default='retrainer') | |||||
parser.add_argument("--log_path", type=str, default='output/log') | |||||
parser.add_argument("--result_path", type=str, default='output/result.json') | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='output/selected_space.json') | |||||
# int | |||||
parser.add_argument("--acc_gap", type=int, default=5) | |||||
parser.add_argument("--batch_size", type=int, default=32) | |||||
parser.add_argument("--cooldown_epochs", type=int, default=10) | |||||
parser.add_argument("--decay_epochs", type=int, default=10) | |||||
parser.add_argument("--epochs", type=int, default=200) | |||||
parser.add_argument("--flops_minimum", type=int, default=0) | |||||
parser.add_argument("--flops_maximum", type=int, default=200) | |||||
parser.add_argument("--image_size", type=int, default=224) | |||||
parser.add_argument("--local_rank", type=int, default=0) | |||||
parser.add_argument("--log_interval", type=int, default=50) | |||||
parser.add_argument("--meta_sta_epoch", type=int, default=20) | |||||
parser.add_argument("--num_classes", type=int, default=1000) | |||||
parser.add_argument("--num_gpu", type=int, default=1) | |||||
parser.add_argument("--parience_epochs", type=int, default=10) | |||||
parser.add_argument("--pool_size", type=int, default=10) | |||||
parser.add_argument("--recovery_interval", type=int, default=10) | |||||
parser.add_argument("--trial_id", type=int, default=42) | |||||
parser.add_argument("--selection", type=int, default=-1) | |||||
parser.add_argument("--slice_num", type=int, default=4) | |||||
parser.add_argument("--tta", type=int, default=0) | |||||
parser.add_argument("--update_iter", type=int, default=1300) | |||||
parser.add_argument("--val_batch_mul", type=int, default=4) | |||||
parser.add_argument("--warmup_epochs", type=int, default=3) | |||||
parser.add_argument("--workers", type=int, default=4) | |||||
# float | |||||
parser.add_argument("--color_jitter", type=float, default=0.4) | |||||
parser.add_argument("--decay_rate", type=float, default=0.1) | |||||
parser.add_argument("--dropout_rate", type=float, default=0.0) | |||||
parser.add_argument("--ema_decay", type=float, default=0.998) | |||||
parser.add_argument("--lr", type=float, default=1e-2) | |||||
parser.add_argument("--meta_lr", type=float, default=1e-4) | |||||
parser.add_argument("--re_prob", type=float, default=0.2) | |||||
parser.add_argument("--opt_eps", type=float, default=1e-2) | |||||
parser.add_argument("--momentum", type=float, default=0.9) | |||||
parser.add_argument("--min_lr", type=float, default=1e-5) | |||||
parser.add_argument("--smoothing", type=float, default=0.1) | |||||
parser.add_argument("--weight_decay", type=float, default=1e-4) | |||||
parser.add_argument("--warmup_lr", type=float, default=1e-4) | |||||
# bool | |||||
parser.add_argument("--auto_resume", type=str2bool, default='False') | |||||
parser.add_argument("--dil_conv", type=str2bool, default='False') | |||||
parser.add_argument("--ema_cpu", type=str2bool, default='False') | |||||
parser.add_argument("--pin_mem", type=str2bool, default='True') | |||||
parser.add_argument("--resunit", type=str2bool, default='False') | |||||
parser.add_argument("--save_images", type=str2bool, default='False') | |||||
parser.add_argument("--sync_bn", type=str2bool, default='False') | |||||
parser.add_argument("--use_ema", type=str2bool, default='False') | |||||
parser.add_argument("--verbose", type=str2bool, default='False') | |||||
# str | |||||
parser.add_argument("--aa", type=str, default='rand-m9-mstd0.5') | |||||
parser.add_argument("--eval_metrics", type=str, default='prec1') | |||||
# gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"] | |||||
parser.add_argument("--gp", type=str, default='avg') | |||||
parser.add_argument("--interpolation", type=str, default='bilinear') | |||||
parser.add_argument("--opt", type=str, default='sgd') | |||||
parser.add_argument("--pick_method", type=str, default='meta') | |||||
parser.add_argument("--re_mode", type=str, default='pixel') | |||||
parser.add_argument("--sched", type=str, default='sgd') | |||||
args = parser.parse_args() | |||||
args.sync_bn = False | |||||
args.verbose = False | |||||
args.data_dir = args.data_dir + "/imagenet" | |||||
return args | |||||
def main(): | |||||
args = parse_args() | |||||
mkdirs(args.checkpoint_dir + "/", | |||||
args.experiment_dir, | |||||
args.best_selected_space_path, | |||||
args.result_path) | |||||
with open(args.result_path, "w") as ss_file: | |||||
ss_file.write('') | |||||
    if len(args.checkpoint_dir) > 1:
mkdirs(args.best_checkpoint_dir + "/") | |||||
args.checkpoint_dir = os.path.join( | |||||
args.checkpoint_dir, | |||||
"{}_{}".format(args.model_name, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) | |||||
) | |||||
if not os.path.exists(args.checkpoint_dir): | |||||
os.mkdir(args.checkpoint_dir) | |||||
# resolve logging | |||||
if args.local_rank == 0: | |||||
logger = get_logger(args.log_path) | |||||
writer = None # SummaryWriter(os.path.join(output_dir, 'runs')) | |||||
else: | |||||
writer, logger = None, None | |||||
# retrain model selection | |||||
if args.selection == -1: | |||||
if os.path.exists(args.best_selected_space_path): | |||||
with open(args.best_selected_space_path, "r") as f: | |||||
arch_list = json.load(f)['selected_space'] | |||||
else: | |||||
args.selection = 14 | |||||
logger.warning("args.best_selected_space_path is not exist. Set selection to 14.") | |||||
if args.selection == 481: | |||||
arch_list = [ | |||||
[0], [ | |||||
3, 4, 3, 1], [ | |||||
3, 2, 3, 0], [ | |||||
3, 3, 3, 1], [ | |||||
3, 3, 3, 3], [ | |||||
3, 3, 3, 3], [0]] | |||||
args.image_size = 224 | |||||
elif args.selection == 43: | |||||
arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]] | |||||
args.image_size = 96 | |||||
elif args.selection == 14: | |||||
arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] | |||||
args.image_size = 64 | |||||
elif args.selection == 112: | |||||
arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]] | |||||
args.image_size = 160 | |||||
elif args.selection == 287: | |||||
arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] | |||||
args.image_size = 224 | |||||
elif args.selection == 604: | |||||
arch_list = [ | |||||
[0], [ | |||||
3, 3, 2, 3, 3], [ | |||||
3, 2, 3, 2, 3], [ | |||||
3, 2, 3, 2, 3], [ | |||||
3, 3, 2, 2, 3, 3], [ | |||||
3, 3, 2, 3, 3, 3], [0]] | |||||
args.image_size = 224 | |||||
elif args.selection == -1: | |||||
args.image_size = 224 | |||||
else: | |||||
raise ValueError("Model Retrain Selection is not Supported!") | |||||
print(arch_list) | |||||
# define childnet architecture from arch_list | |||||
stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25'] | |||||
# TODO: this param from NNI is different from microsoft/Cream. | |||||
choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25', | |||||
'ir_r1_k5_s2_e4_c40_se0.25', | |||||
'ir_r1_k3_s2_e6_c80_se0.25', | |||||
'ir_r1_k3_s1_e6_c96_se0.25', | |||||
'ir_r1_k5_s2_e6_c192_se0.25'] | |||||
arch_def = [[stem[0]]] + [[choice_block_pool[idx] | |||||
for repeat_times in range(len(arch_list[idx + 1]))] | |||||
for idx in range(len(choice_block_pool))] + [[stem[1]]] | |||||
# generate childnet | |||||
model = gen_childnet( | |||||
arch_list, | |||||
arch_def, | |||||
num_classes=args.num_classes, | |||||
drop_rate=args.dropout_rate, | |||||
global_pool=args.gp) | |||||
# initialize distributed parameters | |||||
distributed = args.num_gpu > 1 | |||||
torch.cuda.set_device(args.local_rank) | |||||
if args.local_rank == 0: | |||||
logger.info( | |||||
'Training on Process {} with {} GPUs.'.format( | |||||
args.local_rank, args.num_gpu)) | |||||
# fix random seeds | |||||
torch.manual_seed(args.trial_id) | |||||
torch.cuda.manual_seed_all(args.trial_id) | |||||
np.random.seed(args.trial_id) | |||||
torch.backends.cudnn.deterministic = True | |||||
torch.backends.cudnn.benchmark = False | |||||
# get parameters and FLOPs of model | |||||
if args.local_rank == 0: | |||||
macs, params = get_model_flops_params(model, input_size=( | |||||
1, 3, args.image_size, args.image_size)) | |||||
logger.info( | |||||
'[Model-{}] Flops: {} Params: {}'.format(args.selection, macs, params)) | |||||
# create optimizer | |||||
model = model.cuda() | |||||
optimizer = create_optimizer(args, model) | |||||
# optionally resume from a checkpoint | |||||
resume_epoch = None | |||||
if args.auto_resume: | |||||
if int(timm.__version__[2]) >= 3: | |||||
resume_epoch = resume_checkpoint(model, args.experiment_dir, optimizer) | |||||
else: | |||||
resume_state, resume_epoch = resume_checkpoint(model, args.experiment_dir) | |||||
optimizer.load_state_dict(resume_state['optimizer']) | |||||
del resume_state | |||||
model_ema = None | |||||
if args.use_ema: | |||||
model_ema = ModelEma( | |||||
model, | |||||
decay=args.ema_decay, | |||||
device='cpu' if args.ema_cpu else '', | |||||
resume=args.experiment_dir if args.auto_resume else None) | |||||
# initialize training parameters | |||||
eval_metric = args.eval_metrics | |||||
best_metric, best_epoch, saver = None, None, None | |||||
if args.local_rank == 0: | |||||
decreasing = True if eval_metric == 'loss' else False | |||||
if int(timm.__version__[2]) >= 3: | |||||
saver = CheckpointSaver(model, optimizer, | |||||
checkpoint_dir=args.checkpoint_dir, | |||||
recovery_dir=args.checkpoint_dir, | |||||
model_ema=model_ema, | |||||
decreasing=decreasing, | |||||
max_history=2) | |||||
else: | |||||
saver = CheckpointSaver( | |||||
checkpoint_dir=args.checkpoint_dir, | |||||
recovery_dir=args.checkpoint_dir, | |||||
decreasing=decreasing, | |||||
max_history=2) | |||||
if distributed: | |||||
torch.distributed.init_process_group(backend='nccl', init_method='env://') | |||||
if args.sync_bn: | |||||
try: | |||||
if HAS_APEX: | |||||
model = convert_syncbn_model(model) | |||||
else: | |||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) | |||||
if args.local_rank == 0: | |||||
logger.info('Converted model to use Synchronized BatchNorm.') | |||||
except Exception as e: | |||||
if args.local_rank == 0: | |||||
logger.error( | |||||
'Failed to enable Synchronized BatchNorm. ' | |||||
'Install Apex or Torch >= 1.1 with exception {}'.format(e)) | |||||
if HAS_APEX: | |||||
model = DDP(model, delay_allreduce=True) | |||||
else: | |||||
if args.local_rank == 0: | |||||
logger.info( | |||||
"Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") | |||||
# can use device str in Torch >= 1.1 | |||||
model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True) | |||||
# imagenet train dataset | |||||
train_dir = os.path.join(args.data_dir, 'train') | |||||
if not os.path.exists(train_dir) and args.local_rank == 0: | |||||
logger.error('Training folder does not exist at: {}'.format(train_dir)) | |||||
exit(1) | |||||
dataset_train = Dataset(train_dir) | |||||
loader_train = create_loader( | |||||
dataset_train, | |||||
input_size=(3, args.image_size, args.image_size), | |||||
batch_size=args.batch_size, | |||||
is_training=True, | |||||
color_jitter=args.color_jitter, | |||||
auto_augment=args.aa, | |||||
num_aug_splits=0, | |||||
crop_pct=DEFAULT_CROP_PCT, | |||||
mean=IMAGENET_DEFAULT_MEAN, | |||||
std=IMAGENET_DEFAULT_STD, | |||||
num_workers=args.workers, | |||||
distributed=distributed, | |||||
collate_fn=None, | |||||
pin_memory=args.pin_mem, | |||||
interpolation='random', | |||||
re_mode=args.re_mode, | |||||
re_prob=args.re_prob | |||||
) | |||||
# imagenet validation dataset | |||||
eval_dir = os.path.join(args.data_dir, 'val') | |||||
if not os.path.exists(eval_dir) and args.local_rank == 0: | |||||
logger.error( | |||||
'Validation folder does not exist at: {}'.format(eval_dir)) | |||||
exit(1) | |||||
dataset_eval = Dataset(eval_dir) | |||||
loader_eval = create_loader( | |||||
dataset_eval, | |||||
input_size=(3, args.image_size, args.image_size), | |||||
batch_size=args.val_batch_mul * args.batch_size, | |||||
is_training=False, | |||||
interpolation=args.interpolation, | |||||
crop_pct=DEFAULT_CROP_PCT, | |||||
mean=IMAGENET_DEFAULT_MEAN, | |||||
std=IMAGENET_DEFAULT_STD, | |||||
num_workers=args.workers, | |||||
distributed=distributed, | |||||
pin_memory=args.pin_mem | |||||
) | |||||
# whether to use label smoothing | |||||
if args.smoothing > 0.: | |||||
train_loss_fn = LabelSmoothingCrossEntropy( | |||||
smoothing=args.smoothing).cuda() | |||||
validate_loss_fn = nn.CrossEntropyLoss().cuda() | |||||
else: | |||||
train_loss_fn = nn.CrossEntropyLoss().cuda() | |||||
validate_loss_fn = train_loss_fn | |||||
# create learning rate scheduler | |||||
lr_scheduler, num_epochs = create_scheduler(args, optimizer) | |||||
start_epoch = resume_epoch if resume_epoch is not None else 0 | |||||
if start_epoch > 0: | |||||
lr_scheduler.step(start_epoch) | |||||
if args.local_rank == 0: | |||||
logger.info('Scheduled epochs: {}'.format(num_epochs)) | |||||
try: | |||||
best_record, best_ep = 0, 0 | |||||
for epoch in range(start_epoch, num_epochs): | |||||
if distributed: | |||||
loader_train.sampler.set_epoch(epoch) | |||||
train_metrics = train_epoch( | |||||
epoch, | |||||
model, | |||||
loader_train, | |||||
optimizer, | |||||
train_loss_fn, | |||||
args, | |||||
lr_scheduler=lr_scheduler, | |||||
saver=saver, | |||||
output_dir=args.checkpoint_dir, | |||||
model_ema=model_ema, | |||||
logger=logger, | |||||
writer=writer, | |||||
local_rank=args.local_rank) | |||||
eval_metrics = validate( | |||||
epoch, | |||||
model, | |||||
loader_eval, | |||||
validate_loss_fn, | |||||
args, | |||||
logger=logger, | |||||
writer=writer, | |||||
local_rank=args.local_rank, | |||||
result_path=args.result_path | |||||
) | |||||
if model_ema is not None and not args.ema_cpu: | |||||
ema_eval_metrics = validate( | |||||
epoch, | |||||
model_ema.ema, | |||||
loader_eval, | |||||
validate_loss_fn, | |||||
args, | |||||
log_suffix='_EMA', | |||||
logger=logger, | |||||
writer=writer, | |||||
local_rank=args.local_rank | |||||
) | |||||
eval_metrics = ema_eval_metrics | |||||
if lr_scheduler is not None: | |||||
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) | |||||
update_summary(epoch, train_metrics, eval_metrics, os.path.join( | |||||
args.checkpoint_dir, 'summary.csv'), write_header=best_metric is None) | |||||
if saver is not None: | |||||
# save proper checkpoint with eval metric | |||||
save_metric = eval_metrics[eval_metric] | |||||
if int(timm.__version__[2]) >= 3: | |||||
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) | |||||
else: | |||||
best_metric, best_epoch = saver.save_checkpoint( | |||||
model, optimizer, args, | |||||
epoch=epoch, metric=save_metric) | |||||
if best_record < eval_metrics[eval_metric]: | |||||
best_record = eval_metrics[eval_metric] | |||||
best_ep = epoch | |||||
if args.local_rank == 0: | |||||
logger.info( | |||||
'*** Best metric: {0} (epoch {1})'.format(best_record, best_ep)) | |||||
except KeyboardInterrupt: | |||||
pass | |||||
if best_metric is not None: | |||||
logger.info( | |||||
'*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) | |||||
save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch) | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,21 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
from pytorch.selector import Selector | |||||
class ClassicnasSelector(Selector): | |||||
def __init__(self, *args, single_candidate=True): | |||||
super().__init__(single_candidate) | |||||
self.args = args | |||||
def fit(self): | |||||
""" | |||||
only one candatite, function passed | |||||
""" | |||||
pass | |||||
if __name__ == "__main__": | |||||
hpo_selector = ClassicnasSelector() | |||||
hpo_selector.fit() |
@@ -0,0 +1,167 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT License. | |||||
# Written by Hao Du and Houwen Peng | |||||
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com | |||||
import os | |||||
import warnings | |||||
import datetime | |||||
import torch | |||||
import torch.nn as nn | |||||
# from torch.utils.tensorboard import SummaryWriter | |||||
# import timm packages | |||||
from timm.utils import ModelEma | |||||
from timm.models import resume_checkpoint | |||||
from timm.data import Dataset, create_loader | |||||
# import apex as distributed package | |||||
try: | |||||
from apex.parallel import convert_syncbn_model | |||||
from apex.parallel import DistributedDataParallel as DDP | |||||
HAS_APEX = True | |||||
except ImportError as e: | |||||
print(e) | |||||
from torch.nn.parallel import DistributedDataParallel as DDP | |||||
HAS_APEX = False | |||||
# import models and training functions | |||||
from lib.core.test import validate | |||||
from lib.models.structures.childnet import gen_childnet | |||||
from lib.utils.util import parse_config_args, get_logger, get_model_flops_params | |||||
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |||||
def main(): | |||||
args, cfg = parse_config_args('child net testing') | |||||
# resolve logging | |||||
output_dir = os.path.join(cfg.SAVE_PATH, | |||||
"{}-{}".format(datetime.date.today().strftime('%m%d'), | |||||
cfg.MODEL)) | |||||
if not os.path.exists(output_dir): | |||||
os.mkdir(output_dir) | |||||
if args.local_rank == 0: | |||||
logger = get_logger(os.path.join(output_dir, 'test.log')) | |||||
writer = None # SummaryWriter(os.path.join(output_dir, 'runs')) | |||||
else: | |||||
writer, logger = None, None | |||||
# retrain model selection | |||||
if cfg.NET.SELECTION == 481: | |||||
arch_list = [ | |||||
[0], [ | |||||
3, 4, 3, 1], [ | |||||
3, 2, 3, 0], [ | |||||
3, 3, 3, 1], [ | |||||
3, 3, 3, 3], [ | |||||
3, 3, 3, 3], [0]] | |||||
cfg.DATASET.IMAGE_SIZE = 224 | |||||
elif cfg.NET.SELECTION == 43: | |||||
arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]] | |||||
cfg.DATASET.IMAGE_SIZE = 96 | |||||
elif cfg.NET.SELECTION == 14: | |||||
arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] | |||||
cfg.DATASET.IMAGE_SIZE = 64 | |||||
elif cfg.NET.SELECTION == 112: | |||||
arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]] | |||||
cfg.DATASET.IMAGE_SIZE = 160 | |||||
elif cfg.NET.SELECTION == 287: | |||||
arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] | |||||
cfg.DATASET.IMAGE_SIZE = 224 | |||||
elif cfg.NET.SELECTION == 604: | |||||
arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3], | |||||
[3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]] | |||||
cfg.DATASET.IMAGE_SIZE = 224 | |||||
else: | |||||
raise ValueError("Model Test Selection is not Supported!") | |||||
# define childnet architecture from arch_list | |||||
stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25'] | |||||
# TODO: this param from NNI is different from microsoft/Cream. | |||||
choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25', | |||||
'ir_r1_k5_s2_e4_c40_se0.25', | |||||
'ir_r1_k3_s2_e6_c80_se0.25', | |||||
'ir_r1_k3_s1_e6_c96_se0.25', | |||||
'ir_r1_k5_s2_e6_c192_se0.25'] | |||||
arch_def = [[stem[0]]] + [[choice_block_pool[idx] | |||||
for repeat_times in range(len(arch_list[idx + 1]))] | |||||
for idx in range(len(choice_block_pool))] + [[stem[1]]] | |||||
# generate childnet | |||||
model = gen_childnet( | |||||
arch_list, | |||||
arch_def, | |||||
num_classes=cfg.DATASET.NUM_CLASSES, | |||||
drop_rate=cfg.NET.DROPOUT_RATE, | |||||
global_pool=cfg.NET.GP) | |||||
if args.local_rank == 0: | |||||
macs, params = get_model_flops_params(model, input_size=( | |||||
1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE)) | |||||
logger.info( | |||||
'[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params)) | |||||
# initialize distributed parameters | |||||
torch.cuda.set_device(args.local_rank) | |||||
torch.distributed.init_process_group(backend='nccl', init_method='env://') | |||||
if args.local_rank == 0: | |||||
logger.info( | |||||
"Training on Process {} with {} GPUs.".format( | |||||
args.local_rank, cfg.NUM_GPU)) | |||||
# resume model from checkpoint | |||||
assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH) | |||||
resume_checkpoint(model, cfg.RESUME_PATH) | |||||
model = model.cuda() | |||||
model_ema = None | |||||
if cfg.NET.EMA.USE: | |||||
# Important to create EMA model after cuda(), DP wrapper, and AMP but | |||||
# before SyncBN and DDP wrapper | |||||
model_ema = ModelEma( | |||||
model, | |||||
decay=cfg.NET.EMA.DECAY, | |||||
device='cpu' if cfg.NET.EMA.FORCE_CPU else '', | |||||
resume=cfg.RESUME_PATH) | |||||
# imagenet validation dataset | |||||
eval_dir = os.path.join(cfg.DATA_DIR, 'val') | |||||
    if not os.path.exists(eval_dir):
        if args.local_rank == 0:
            logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
dataset_eval = Dataset(eval_dir) | |||||
loader_eval = create_loader( | |||||
dataset_eval, | |||||
input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), | |||||
batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE, | |||||
is_training=False, | |||||
num_workers=cfg.WORKERS, | |||||
distributed=True, | |||||
pin_memory=cfg.DATASET.PIN_MEM, | |||||
crop_pct=DEFAULT_CROP_PCT, | |||||
mean=IMAGENET_DEFAULT_MEAN, | |||||
std=IMAGENET_DEFAULT_STD | |||||
) | |||||
    # test accuracy of the model, and of model-EMA if enabled
    validate_loss_fn = nn.CrossEntropyLoss().cuda()
    validate(0, model, loader_eval, validate_loss_fn, cfg,
             log_suffix='', logger=logger,
             writer=writer, local_rank=args.local_rank)
if cfg.NET.EMA.USE: | |||||
validate(0, model_ema.ema, loader_eval, validate_loss_fn, cfg, | |||||
log_suffix='_EMA', logger=logger, | |||||
writer=writer, local_rank=args.local_rank) | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,312 @@ | |||||
# https://github.com/microsoft/nni/blob/v2.0/examples/nas/cream/train.py | |||||
import sys | |||||
sys.path.append('../..') | |||||
import os | |||||
import time | |||||
import json | |||||
import torch | |||||
import numpy as np | |||||
import torch.nn as nn | |||||
from argparse import ArgumentParser | |||||
# import timm packages | |||||
from timm.loss import LabelSmoothingCrossEntropy | |||||
from timm.data import Dataset, create_loader | |||||
from timm.models import resume_checkpoint | |||||
# import apex as distributed package | |||||
# try: | |||||
# from apex.parallel import DistributedDataParallel as DDP | |||||
# from apex.parallel import convert_syncbn_model | |||||
# | |||||
# USE_APEX = True | |||||
# except ImportError as e: | |||||
# print(e) | |||||
# from torch.nn.parallel import DistributedDataParallel as DDP | |||||
# | |||||
# USE_APEX = False | |||||
# import models and training functions | |||||
from lib.utils.flops_table import FlopsEst | |||||
from lib.models.structures.supernet import gen_supernet | |||||
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN | |||||
from lib.utils.util import get_logger, \ | |||||
create_optimizer_supernet, create_supernet_scheduler | |||||
from pytorch.utils import mkdirs, str2bool | |||||
from pytorch.callbacks import LRSchedulerCallback | |||||
from pytorch.callbacks import ModelCheckpoint | |||||
from algorithms import CreamSupernetTrainer | |||||
from algorithms import RandomMutator | |||||
def parse_args(): | |||||
"""See lib.utils.config""" | |||||
parser = ArgumentParser() | |||||
# path | |||||
parser.add_argument("--checkpoint_dir", type=str, default='') | |||||
parser.add_argument("--data_dir", type=str, default='./data') | |||||
parser.add_argument("--experiment_dir", type=str, default='./') | |||||
parser.add_argument("--model_name", type=str, default='trainer') | |||||
parser.add_argument("--log_path", type=str, default='output/log') | |||||
parser.add_argument("--result_path", type=str, default='output/result.json') | |||||
parser.add_argument("--search_space_path", type=str, default='output/search_space.json') | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='output/selected_space.json') | |||||
# int | |||||
parser.add_argument("--acc_gap", type=int, default=5) | |||||
parser.add_argument("--batch_size", type=int, default=1) | |||||
parser.add_argument("--epochs", type=int, default=200) | |||||
parser.add_argument("--flops_minimum", type=int, default=0) | |||||
parser.add_argument("--flops_maximum", type=int, default=200) | |||||
parser.add_argument("--image_size", type=int, default=224) | |||||
parser.add_argument("--local_rank", type=int, default=0) | |||||
parser.add_argument("--log_interval", type=int, default=50) | |||||
parser.add_argument("--meta_sta_epoch", type=int, default=20) | |||||
parser.add_argument("--num_classes", type=int, default=1000) | |||||
parser.add_argument("--num_gpu", type=int, default=1) | |||||
parser.add_argument("--pool_size", type=int, default=10) | |||||
parser.add_argument("--trial_id", type=int, default=42) | |||||
parser.add_argument("--slice_num", type=int, default=4) | |||||
parser.add_argument("--tta", type=int, default=0) | |||||
parser.add_argument("--update_iter", type=int, default=1300) | |||||
parser.add_argument("--workers", type=int, default=4) | |||||
# float | |||||
parser.add_argument("--color_jitter", type=float, default=0.4) | |||||
parser.add_argument("--dropout_rate", type=float, default=0.0) | |||||
parser.add_argument("--lr", type=float, default=1e-2) | |||||
parser.add_argument("--meta_lr", type=float, default=1e-4) | |||||
parser.add_argument("--opt_eps", type=float, default=1e-2) | |||||
parser.add_argument("--re_prob", type=float, default=0.2) | |||||
parser.add_argument("--momentum", type=float, default=0.9) | |||||
parser.add_argument("--smoothing", type=float, default=0.1) | |||||
parser.add_argument("--weight_decay", type=float, default=1e-4) | |||||
# bool | |||||
parser.add_argument("--auto_resume", type=str2bool, default='False') | |||||
parser.add_argument("--dil_conv", type=str2bool, default='False') | |||||
parser.add_argument("--resunit", type=str2bool, default='False') | |||||
parser.add_argument("--sync_bn", type=str2bool, default='False') | |||||
parser.add_argument("--verbose", type=str2bool, default='False') | |||||
# str | |||||
# gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"] | |||||
parser.add_argument("--gp", type=str, default='avg') | |||||
parser.add_argument("--interpolation", type=str, default='bilinear') | |||||
parser.add_argument("--opt", type=str, default='sgd') | |||||
parser.add_argument("--pick_method", type=str, default='meta') | |||||
parser.add_argument("--re_mode", type=str, default='pixel') | |||||
args = parser.parse_args() | |||||
args.sync_bn = False | |||||
args.verbose = False | |||||
args.data_dir = args.data_dir + "/imagenet" | |||||
return args | |||||
def main(): | |||||
args = parse_args() | |||||
mkdirs(args.experiment_dir, | |||||
args.best_selected_space_path, | |||||
args.search_space_path, | |||||
args.result_path, | |||||
args.log_path) | |||||
with open(args.result_path, "w") as ss_file: | |||||
ss_file.write('') | |||||
# resolve logging | |||||
    if len(args.checkpoint_dir) > 1:
mkdirs(args.checkpoint_dir + "/") | |||||
args.checkpoint_dir = os.path.join( | |||||
args.checkpoint_dir, | |||||
"{}_{}".format(args.model_name, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) | |||||
) | |||||
if not os.path.exists(args.checkpoint_dir): | |||||
os.mkdir(args.checkpoint_dir) | |||||
if args.local_rank == 0: | |||||
logger = get_logger(args.log_path) | |||||
else: | |||||
logger = None | |||||
# initialize distributed parameters | |||||
torch.cuda.set_device(args.local_rank) | |||||
# torch.distributed.init_process_group(backend='nccl', init_method='env://') | |||||
if args.local_rank == 0: | |||||
logger.info( | |||||
'Training on Process %d with %d GPUs.', | |||||
args.local_rank, args.num_gpu) | |||||
# fix random seeds | |||||
torch.manual_seed(args.trial_id) | |||||
torch.cuda.manual_seed_all(args.trial_id) | |||||
np.random.seed(args.trial_id) | |||||
torch.backends.cudnn.deterministic = True | |||||
torch.backends.cudnn.benchmark = False | |||||
# generate supernet and optimizer | |||||
model, sta_num, resolution, search_space = gen_supernet( | |||||
flops_minimum=args.flops_minimum, | |||||
flops_maximum=args.flops_maximum, | |||||
num_classes=args.num_classes, | |||||
drop_rate=args.dropout_rate, | |||||
global_pool=args.gp, | |||||
resunit=args.resunit, | |||||
dil_conv=args.dil_conv, | |||||
slice=args.slice_num, | |||||
verbose=args.verbose, | |||||
logger=logger) | |||||
optimizer = create_optimizer_supernet(args, model) | |||||
# number of choice blocks in supernet | |||||
choice_num = len(model.blocks[7]) | |||||
if args.local_rank == 0: | |||||
logger.info('Supernet created, param count: %d', ( | |||||
sum([m.numel() for m in model.parameters()]))) | |||||
logger.info('resolution: %d', resolution) | |||||
logger.info('choice number: %d', choice_num) | |||||
with open(args.search_space_path, "w") as f: | |||||
print("dump search space.") | |||||
json.dump({'search_space': search_space}, f) | |||||
# initialize flops look-up table | |||||
model_est = FlopsEst(model) | |||||
flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed | |||||
model = model.cuda() | |||||
# convert model to distributed mode | |||||
if args.sync_bn: | |||||
try: | |||||
# if USE_APEX: | |||||
# model = convert_syncbn_model(model) | |||||
# else: | |||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) | |||||
if args.local_rank == 0: | |||||
logger.info('Converted model to use Synchronized BatchNorm.') | |||||
except Exception as exception: | |||||
            logger.info(
                'Failed to enable Synchronized BatchNorm. '
                'Install Apex or upgrade to Torch >= 1.1. Exception: %s', exception)
# if USE_APEX: | |||||
# model = DDP(model, delay_allreduce=True) | |||||
# else: | |||||
# if args.local_rank == 0: | |||||
# logger.info( | |||||
# "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") | |||||
# # can use device str in Torch >= 1.1 | |||||
# model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True) | |||||
# optionally resume from a checkpoint | |||||
resume_epoch = None | |||||
if False: # args.auto_resume: | |||||
checkpoint = torch.load(args.experiment_dir) | |||||
model.load_state_dict(checkpoint['child_model_state_dict']) | |||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | |||||
resume_epoch = checkpoint['epoch'] | |||||
# create learning rate scheduler | |||||
lr_scheduler, num_epochs = create_supernet_scheduler(optimizer, args.epochs, args.num_gpu, | |||||
args.batch_size, args.lr) | |||||
start_epoch = resume_epoch if resume_epoch is not None else 0 | |||||
if start_epoch > 0: | |||||
lr_scheduler.step(start_epoch) | |||||
if args.local_rank == 0: | |||||
logger.info('Scheduled epochs: %d', num_epochs) | |||||
# imagenet train dataset | |||||
train_dir = os.path.join(args.data_dir, 'train') | |||||
    if not os.path.exists(train_dir):
        logger.error('Training folder does not exist at: %s', train_dir)
        sys.exit()
dataset_train = Dataset(train_dir) | |||||
loader_train = create_loader( | |||||
dataset_train, | |||||
input_size=(3, args.image_size, args.image_size), | |||||
batch_size=args.batch_size, | |||||
is_training=True, | |||||
use_prefetcher=True, | |||||
re_prob=args.re_prob, | |||||
re_mode=args.re_mode, | |||||
color_jitter=args.color_jitter, | |||||
interpolation='random', | |||||
num_workers=args.workers, | |||||
distributed=False, | |||||
collate_fn=None, | |||||
crop_pct=DEFAULT_CROP_PCT, | |||||
mean=IMAGENET_DEFAULT_MEAN, | |||||
std=IMAGENET_DEFAULT_STD | |||||
) | |||||
# imagenet validation dataset | |||||
eval_dir = os.path.join(args.data_dir, 'val') | |||||
    if not os.path.isdir(eval_dir):
        logger.error('Validation folder does not exist at: %s', eval_dir)
        sys.exit()
dataset_eval = Dataset(eval_dir) | |||||
loader_eval = create_loader( | |||||
dataset_eval, | |||||
input_size=(3, args.image_size, args.image_size), | |||||
batch_size=4 * args.batch_size, | |||||
is_training=False, | |||||
use_prefetcher=True, | |||||
num_workers=args.workers, | |||||
distributed=False, | |||||
crop_pct=DEFAULT_CROP_PCT, | |||||
mean=IMAGENET_DEFAULT_MEAN, | |||||
std=IMAGENET_DEFAULT_STD, | |||||
interpolation=args.interpolation | |||||
) | |||||
# whether to use label smoothing | |||||
if args.smoothing > 0.: | |||||
train_loss_fn = LabelSmoothingCrossEntropy( | |||||
smoothing=args.smoothing).cuda() | |||||
validate_loss_fn = nn.CrossEntropyLoss().cuda() | |||||
else: | |||||
train_loss_fn = nn.CrossEntropyLoss().cuda() | |||||
validate_loss_fn = train_loss_fn | |||||
mutator = RandomMutator(model) | |||||
_callbacks = [LRSchedulerCallback(lr_scheduler)] | |||||
if len(args.checkpoint_dir) > 1: | |||||
        _callbacks.append(ModelCheckpoint(args.checkpoint_dir))
trainer = CreamSupernetTrainer(args.best_selected_space_path, model, train_loss_fn, | |||||
validate_loss_fn, | |||||
optimizer, num_epochs, loader_train, loader_eval, | |||||
result_path=args.result_path, | |||||
mutator=mutator, | |||||
batch_size=args.batch_size, | |||||
log_frequency=args.log_interval, | |||||
meta_sta_epoch=args.meta_sta_epoch, | |||||
update_iter=args.update_iter, | |||||
slices=args.slice_num, | |||||
pool_size=args.pool_size, | |||||
pick_method=args.pick_method, | |||||
choice_num=choice_num, | |||||
sta_num=sta_num, | |||||
acc_gap=args.acc_gap, | |||||
flops_dict=flops_dict, | |||||
flops_fixed=flops_fixed, | |||||
local_rank=args.local_rank, | |||||
callbacks=_callbacks) | |||||
trainer.train() | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,2 @@ | |||||
from pytorch.darts.dartstrainer import DartsTrainer | |||||
from pytorch.darts.dartsmutator import DartsMutator |
@@ -0,0 +1,205 @@ | |||||
import sys | |||||
sys.path.append('../..')
import os | |||||
import logging | |||||
import time | |||||
from argparse import ArgumentParser | |||||
import torch | |||||
import torch.nn as nn | |||||
# from torch.utils.tensorboard import SummaryWriter | |||||
import datasets | |||||
import utils | |||||
from model import CNN | |||||
from pytorch.utils import set_seed, mkdirs, init_logger, save_best_checkpoint, AverageMeter | |||||
from pytorch.fixed import apply_fixed_architecture | |||||
from pytorch.retrainer import Retrainer | |||||
logger = logging.getLogger(__name__) | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
# writer = SummaryWriter() | |||||
class DartsRetrainer(Retrainer): | |||||
def __init__(self, aux_weight, grad_clip, epochs, log_frequency): | |||||
self.aux_weight = aux_weight | |||||
self.grad_clip = grad_clip | |||||
self.epochs = epochs | |||||
self.log_frequency = log_frequency | |||||
def train(self, train_loader, model, optimizer, criterion, epoch): | |||||
top1 = AverageMeter("top1") | |||||
top5 = AverageMeter("top5") | |||||
losses = AverageMeter("losses") | |||||
cur_step = epoch * len(train_loader) | |||||
cur_lr = optimizer.param_groups[0]["lr"] | |||||
logger.info("Epoch %d LR %.6f", epoch, cur_lr) | |||||
# writer.add_scalar("lr", cur_lr, global_step=cur_step) | |||||
model.train() | |||||
for step, (x, y) in enumerate(train_loader): | |||||
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) | |||||
bs = x.size(0) | |||||
optimizer.zero_grad() | |||||
logits, aux_logits = model(x) | |||||
loss = criterion(logits, y) | |||||
if self.aux_weight > 0.: | |||||
loss += self.aux_weight * criterion(aux_logits, y) | |||||
loss.backward() | |||||
# gradient clipping | |||||
nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip) | |||||
optimizer.step() | |||||
accuracy = utils.accuracy(logits, y, topk=(1, 5)) | |||||
losses.update(loss.item(), bs) | |||||
top1.update(accuracy["acc1"], bs) | |||||
top5.update(accuracy["acc5"], bs) | |||||
# writer.add_scalar("loss/train", loss.item(), global_step=cur_step) | |||||
# writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step) | |||||
# writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step) | |||||
if step % self.log_frequency == 0 or step == len(train_loader) - 1: | |||||
logger.info( | |||||
"Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " | |||||
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( | |||||
epoch + 1, self.epochs, step, len(train_loader) - 1, losses=losses, | |||||
top1=top1, top5=top5)) | |||||
cur_step += 1 | |||||
logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg)) | |||||
def validate(self, valid_loader, model, criterion, epoch, cur_step): | |||||
top1 = AverageMeter("top1") | |||||
top5 = AverageMeter("top5") | |||||
losses = AverageMeter("losses") | |||||
model.eval() | |||||
with torch.no_grad(): | |||||
for step, (X, y) in enumerate(valid_loader): | |||||
X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) | |||||
bs = X.size(0) | |||||
logits = model(X) | |||||
loss = criterion(logits, y) | |||||
accuracy = utils.accuracy(logits, y, topk=(1, 5)) | |||||
losses.update(loss.item(), bs) | |||||
top1.update(accuracy["acc1"], bs) | |||||
top5.update(accuracy["acc5"], bs) | |||||
if step % self.log_frequency == 0 or step == len(valid_loader) - 1: | |||||
logger.info( | |||||
"Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " | |||||
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( | |||||
epoch + 1, self.epochs, step, len(valid_loader) - 1, losses=losses, | |||||
top1=top1, top5=top5)) | |||||
# writer.add_scalar("loss/test", losses.avg, global_step=cur_step) | |||||
# writer.add_scalar("acc1/test", top1.avg, global_step=cur_step) | |||||
# writer.add_scalar("acc5/test", top5.avg, global_step=cur_step) | |||||
logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg)) | |||||
return top1.avg | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("DARTS retrain") | |||||
parser.add_argument("--data_dir", type=str, | |||||
default='./data/', help="search_space json file") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='.0/result.json', help="training result") | |||||
parser.add_argument("--log_path", type=str, | |||||
default='.0/log', help="log for info") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
parser.add_argument("--best_checkpoint_dir", type=str, | |||||
default='./', help="default name is best_checkpoint_epoch{}.pth") | |||||
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
parser.add_argument("--layers", default=20, type=int) | |||||
parser.add_argument("--lr", default=0.025, type=float) | |||||
parser.add_argument("--batch_size", default=128, type=int) | |||||
parser.add_argument("--log_frequency", default=10, type=int) | |||||
parser.add_argument("--epochs", default=5, type=int) | |||||
parser.add_argument("--aux_weight", default=0.4, type=float) | |||||
parser.add_argument("--drop_path_prob", default=0.2, type=float) | |||||
parser.add_argument("--workers", default=4, type=int) | |||||
parser.add_argument("--channels", default=36, type=int) | |||||
parser.add_argument("--grad_clip", default=5., type=float) | |||||
parser.add_argument("--class_num", default=10, type=int, help="cifar10") | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir) | |||||
init_logger(args.log_path) | |||||
logger.info(args) | |||||
set_seed(args.trial_id) | |||||
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir) | |||||
model = CNN(32, 3, args.channels, args.class_num, args.layers, auxiliary=True) | |||||
apply_fixed_architecture(model, args.best_selected_space_path) | |||||
criterion = nn.CrossEntropyLoss() | |||||
model.to(device) | |||||
criterion.to(device) | |||||
optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6) | |||||
train_loader = torch.utils.data.DataLoader(dataset_train, | |||||
batch_size=args.batch_size, | |||||
shuffle=True, | |||||
num_workers=args.workers, | |||||
pin_memory=True) | |||||
valid_loader = torch.utils.data.DataLoader(dataset_valid, | |||||
batch_size=args.batch_size, | |||||
shuffle=False, | |||||
num_workers=args.workers, | |||||
pin_memory=True) | |||||
retrainer = DartsRetrainer(aux_weight=args.aux_weight, | |||||
grad_clip=args.grad_clip, | |||||
epochs=args.epochs, | |||||
                               log_frequency=args.log_frequency)
# result = {"Accuracy": [], "Cost_time": ''} | |||||
best_top1 = 0. | |||||
start_time = time.time() | |||||
with open(args.result_path, "w") as file: | |||||
file.write('') | |||||
for epoch in range(args.epochs): | |||||
drop_prob = args.drop_path_prob * epoch / args.epochs | |||||
model.drop_path_prob(drop_prob) | |||||
# training | |||||
retrainer.train(train_loader, model, optimizer, criterion, epoch) | |||||
# validation | |||||
cur_step = (epoch + 1) * len(train_loader) | |||||
top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step) | |||||
        # the backend filters the terminal output for lines like:
        # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) + '\n') | |||||
# result["Accuracy"].append(top1) | |||||
best_top1 = max(best_top1, top1) | |||||
lr_scheduler.step() | |||||
logger.info("Final best Prec@1 = {:.4%}".format(best_top1)) | |||||
cost_time = time.time() - start_time | |||||
    # the backend filters the terminal output for lines like: {"type": "Cost_time", "result": {"value": "* s"}}
logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) | |||||
# result["Cost_time"] = str(cost_time) + ' s' | |||||
# dump_global_result(args.result_path, result) | |||||
save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch) | |||||
logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch)))) | |||||
@@ -0,0 +1,21 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
from pytorch.selector import Selector | |||||
from argparse import ArgumentParser | |||||
class DartsSelector(Selector): | |||||
def __init__(self, single_candidate=True): | |||||
super().__init__(single_candidate) | |||||
def fit(self): | |||||
pass | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("DARTS select") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
args = parser.parse_args() | |||||
darts_selector = DartsSelector(True) | |||||
darts_selector.fit() |
@@ -0,0 +1,85 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import sys | |||||
sys.path.append('../..')
import logging
import time
from argparse import ArgumentParser | |||||
import torch | |||||
import torch.nn as nn | |||||
import datasets | |||||
from model import CNN | |||||
from utils import accuracy | |||||
from dartstrainer import DartsTrainer | |||||
from pytorch.utils import * | |||||
from pytorch.callbacks import BestArchitectureCheckpoint, LRSchedulerCallback | |||||
logger = logging.getLogger(__name__) | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("DARTS train") | |||||
parser.add_argument("--data_dir", type=str, | |||||
default='../data/', help="search_space json file") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='.0/result.json', help="training result") | |||||
parser.add_argument("--log_path", type=str, | |||||
default='.0/log', help="log for info") | |||||
parser.add_argument("--search_space_path", type=str, | |||||
default='./search_space.json', help="search space of PDARTS") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
parser.add_argument("--layers", default=8, type=int) | |||||
parser.add_argument("--batch_size", default=64, type=int) | |||||
parser.add_argument("--log_frequency", default=10, type=int) | |||||
parser.add_argument("--epochs", default=5, type=int) | |||||
parser.add_argument("--channels", default=16, type=int) | |||||
parser.add_argument('--model_lr', type=float, default=0.025, help='learning rate for training model weights') | |||||
parser.add_argument('--arch_lr', type=float, default=3e-4, help='learning rate for training architecture') | |||||
parser.add_argument("--unrolled", default=False, action="store_true") | |||||
parser.add_argument("--visualization", default=False, action="store_true") | |||||
parser.add_argument("--class_num", default=10, type=int, help="cifar10") | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path) | |||||
init_logger(args.log_path, "info") | |||||
logger.info(args) | |||||
set_seed(args.trial_id) | |||||
dataset_train, dataset_valid = datasets.get_dataset("cifar10", root=args.data_dir) | |||||
model = CNN(32, 3, args.channels, args.class_num, args.layers) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optim = torch.optim.SGD(model.parameters(), args.model_lr, momentum=0.9, weight_decay=3.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) | |||||
trainer = DartsTrainer(model, | |||||
loss=criterion, | |||||
metrics=lambda output, target: accuracy(output, target, topk=(1,)), | |||||
optimizer=optim, | |||||
num_epochs=args.epochs, | |||||
dataset_train=dataset_train, | |||||
dataset_valid=dataset_valid, | |||||
                           search_space_path=args.search_space_path,
batch_size=args.batch_size, | |||||
log_frequency=args.log_frequency, | |||||
result_path=args.result_path, | |||||
unrolled=args.unrolled, | |||||
arch_lr=args.arch_lr, | |||||
callbacks=[LRSchedulerCallback(lr_scheduler), BestArchitectureCheckpoint(args.best_selected_space_path, args.epochs)]) | |||||
if args.visualization: | |||||
trainer.enable_visualization() | |||||
t1 = time.time() | |||||
trainer.train() | |||||
# res_json = trainer.result | |||||
cost_time = time.time() - t1 | |||||
    # the backend filters the terminal output for lines like: {"type": "Cost_time", "result": {"value": "* s"}}
logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) | |||||
# res_json["Cost_time"] = str(cost_time) + ' s' | |||||
# dump_global_result(args.result_path, res_json) |
@@ -0,0 +1,134 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import logging | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from collections import OrderedDict | |||||
from pytorch.mutator import Mutator | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
_logger = logging.getLogger(__name__) | |||||
class DartsMutator(Mutator): | |||||
""" | |||||
Connects the model in a DARTS (differentiable) way. | |||||
    An extra connection is automatically inserted for each LayerChoice; when this connection is selected, no op runs
    on that LayerChoice (i.e., a ``ZeroOp``), and in that case every element in the exported choice list is ``false``
    (not chosen).
    All input choices are fully connected in the search phase. On exporting, each input choice selects its inputs
    based on the keys in ``choose_from``. If a key is the key of a LayerChoice, the top logit of that LayerChoice
    joins the competition of the input choice against the other logits. Otherwise, the logit is assumed to be 0.
    It's possible to cut branches by setting a particular position of the ``choices`` parameter to ``-inf``. After
    softmax, that value becomes 0, and the framework ignores 0 values and does not connect. Note that the gradient
    at the ``-inf`` location will be 0, and since arithmetic with ``-inf`` produces ``nan``, you need to handle the
    gradient update phase carefully.
Attributes | |||||
---------- | |||||
choices: ParameterDict | |||||
dict that maps keys of LayerChoices to weighted-connection float tensors. | |||||
""" | |||||
def __init__(self, model): | |||||
super().__init__(model) | |||||
self.choices = nn.ParameterDict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1)) | |||||
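    # all architecture parameters live on the same device, so any one of them is enough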
def device(self): | |||||
for v in self.choices.values(): | |||||
return v.device | |||||
def sample_search(self): | |||||
result = dict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1] | |||||
elif isinstance(mutable, InputChoice): | |||||
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) | |||||
return result | |||||
def sample_final(self): | |||||
result = dict() | |||||
edges_max = dict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0) | |||||
edges_max[mutable.key] = max_val | |||||
result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, InputChoice): | |||||
if mutable.n_chosen is not None: | |||||
weights = [] | |||||
for src_key in mutable.choose_from: | |||||
if src_key not in edges_max: | |||||
_logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key) | |||||
weights.append(edges_max.get(src_key, 0.)) | |||||
weights = torch.tensor(weights) # pylint: disable=not-callable | |||||
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen) | |||||
selected_multihot = [] | |||||
for i, src_key in enumerate(mutable.choose_from): | |||||
if i not in topk_edge_indices and src_key in result: | |||||
# If an edge is never selected, there is no need to calculate any op on this edge. | |||||
# This is to eliminate redundant calculation. | |||||
result[src_key] = torch.zeros_like(result[src_key]) | |||||
selected_multihot.append(i in topk_edge_indices) | |||||
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable | |||||
else: | |||||
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable | |||||
return result | |||||
def _generate_search_space(self): | |||||
""" | |||||
Generate search space from mutables. | |||||
Here is the search space format: | |||||
:: | |||||
{ key_name: {"_type": "layer_choice", | |||||
"_value": ["conv1", "conv2"]} } | |||||
{ key_name: {"_type": "input_choice", | |||||
"_value": {"candidates": ["in1", "in2"], | |||||
"n_chosen": 1}} } | |||||
Returns | |||||
------- | |||||
dict | |||||
the generated search space | |||||
""" | |||||
res = OrderedDict() | |||||
res["op_list"] = OrderedDict() | |||||
res["search_space"] = OrderedDict() | |||||
# res["normal_cell"] = OrderedDict(), | |||||
# res["reduction_cell"] = OrderedDict() | |||||
keys = [] | |||||
for mutable in self.mutables: | |||||
# for now we only generate flattened search space | |||||
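            # 36 = keys of one normal cell plus one reduction cell for the default
            # 4-node cell: (2 + 3 + 4 + 5) layer choices + 4 input choices = 18 per cell type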
if (len(res["search_space"])) >= 36: | |||||
break | |||||
if isinstance(mutable, LayerChoice): | |||||
key = mutable.key | |||||
if key not in keys: | |||||
val = mutable.names | |||||
if not res["op_list"]: | |||||
res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]} | |||||
# node_type = "normal_cell" if "normal" in key else "reduction_cell" | |||||
res["search_space"][key] = "op_list" | |||||
keys.append(key) | |||||
elif isinstance(mutable, InputChoice): | |||||
key = mutable.key | |||||
if key not in keys: | |||||
# node_type = "normal_cell" if "normal" in key else "reduction_cell" | |||||
res["search_space"][key] = {"_type": "input_choice", | |||||
"_value": {"candidates": mutable.choose_from, | |||||
"n_chosen": mutable.n_chosen}} | |||||
keys.append(key) | |||||
else: | |||||
raise TypeError("Unsupported mutable type: '%s'." % type(mutable)) | |||||
return res |
@@ -0,0 +1,227 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import copy | |||||
import logging | |||||
import torch | |||||
import torch.nn as nn | |||||
from pytorch.trainer import Trainer | |||||
from pytorch.utils import AverageMeterGroup, dump_global_result | |||||
from pytorch.darts.dartsmutator import DartsMutator | |||||
import json | |||||
logger = logging.getLogger(__name__) | |||||
class DartsTrainer(Trainer): | |||||
""" | |||||
DARTS trainer. | |||||
Parameters | |||||
---------- | |||||
model : nn.Module | |||||
PyTorch model to be trained. | |||||
loss : callable | |||||
Receives logits and ground truth label, return a loss tensor. | |||||
metrics : callable | |||||
Receives logits and ground truth label, return a dict of metrics. | |||||
optimizer : Optimizer | |||||
The optimizer used for optimizing the model. | |||||
num_epochs : int | |||||
Number of epochs planned for training. | |||||
dataset_train : Dataset | |||||
Dataset for training. Will be split for training weights and architecture weights. | |||||
dataset_valid : Dataset | |||||
Dataset for testing. | |||||
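    search_space_path : str
        Path to dump the generated search space JSON to; ``None`` skips dumping.
    result_path : str
        File that per-epoch accuracy records are appended to.
    num_pre_epochs : int
        Number of warm-up epochs during which only child weights are trained,
        before architecture optimization starts.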
mutator : DartsMutator | |||||
Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator. | |||||
batch_size : int | |||||
Batch size. | |||||
workers : int | |||||
Workers for data loading. | |||||
device : torch.device | |||||
``torch.device("cpu")`` or ``torch.device("cuda")``. | |||||
log_frequency : int | |||||
Step count per logging. | |||||
callbacks : list of Callback | |||||
list of callbacks to trigger at events. | |||||
arch_lr : float | |||||
Learning rate of architecture parameters. | |||||
    unrolled : bool
        ``True`` if using second order optimization, else first order optimization.
""" | |||||
def __init__(self, model, loss, metrics, | |||||
optimizer, num_epochs, dataset_train, dataset_valid, search_space_path, result_path, num_pre_epochs=0, | |||||
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, | |||||
callbacks=None, arch_lr=3.0E-4, unrolled=False): | |||||
super().__init__(model, mutator if mutator is not None else DartsMutator(model), | |||||
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, | |||||
batch_size, workers, device, log_frequency, callbacks) | |||||
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arch_lr, betas=(0.5, 0.999), weight_decay=1.0E-3) | |||||
self.unrolled = unrolled | |||||
        self.num_pre_epochs = num_pre_epochs
self.result_path = result_path | |||||
with open(self.result_path, "w") as file: | |||||
file.write('') | |||||
n_train = len(self.dataset_train) | |||||
split = n_train // 2 | |||||
indices = list(range(n_train)) | |||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split]) | |||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) | |||||
self.train_loader = torch.utils.data.DataLoader(self.dataset_train, | |||||
batch_size=batch_size, | |||||
sampler=train_sampler, | |||||
num_workers=workers) | |||||
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, | |||||
batch_size=batch_size, | |||||
sampler=valid_sampler, | |||||
num_workers=workers) | |||||
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, | |||||
batch_size=batch_size, | |||||
num_workers=workers) | |||||
if search_space_path is not None: | |||||
dump_global_result(search_space_path, self.mutator._generate_search_space()) | |||||
# self.result = {"Accuracy": []} | |||||
def train_one_epoch(self, epoch): | |||||
self.model.train() | |||||
self.mutator.train() | |||||
meters = AverageMeterGroup() | |||||
# t1 = time() | |||||
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): | |||||
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) | |||||
val_X, val_y = val_X.to(self.device), val_y.to(self.device) | |||||
            if epoch >= self.num_pre_epochs:
                # phase 1: architecture step
self.ctrl_optim.zero_grad() | |||||
if self.unrolled: | |||||
self._unrolled_backward(trn_X, trn_y, val_X, val_y) | |||||
else: | |||||
self._backward(val_X, val_y) | |||||
self.ctrl_optim.step() | |||||
# phase 2: child network step | |||||
self.optimizer.zero_grad() | |||||
logits, loss = self._logits_and_loss(trn_X, trn_y) | |||||
loss.backward() | |||||
nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping | |||||
self.optimizer.step() | |||||
metrics = self.metrics(logits, trn_y) | |||||
metrics["loss"] = loss.item() | |||||
meters.update(metrics) | |||||
if self.log_frequency is not None and step % self.log_frequency == 0: | |||||
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, | |||||
self.num_epochs, step + 1, len(self.train_loader), meters) | |||||
def validate_one_epoch(self, epoch, log_print=True): | |||||
self.model.eval() | |||||
self.mutator.eval() | |||||
meters = AverageMeterGroup() | |||||
with torch.no_grad(): | |||||
self.mutator.reset() | |||||
for step, (X, y) in enumerate(self.test_loader): | |||||
X, y = X.to(self.device), y.to(self.device) | |||||
logits = self.model(X) | |||||
metrics = self.metrics(logits, y) | |||||
meters.update(metrics) | |||||
if self.log_frequency is not None and step % self.log_frequency == 0: | |||||
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, | |||||
self.num_epochs, step + 1, len(self.test_loader), meters) | |||||
if log_print: | |||||
            # the backend filters the terminal output for lines like:
            # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) | |||||
with open(self.result_path, "a") as file: | |||||
file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) + '\n') | |||||
# self.result["Accuracy"].append(meters.get_last_acc()) | |||||
def _logits_and_loss(self, X, y): | |||||
self.mutator.reset() | |||||
logits = self.model(X) | |||||
loss = self.loss(logits, y) | |||||
# self._write_graph_status() | |||||
return logits, loss | |||||
def _backward(self, val_X, val_y): | |||||
""" | |||||
Simple backward with gradient descent | |||||
""" | |||||
_, loss = self._logits_and_loss(val_X, val_y) | |||||
loss.backward() | |||||
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y): | |||||
""" | |||||
Compute unrolled loss and backward its gradients | |||||
""" | |||||
backup_params = copy.deepcopy(tuple(self.model.parameters())) | |||||
# do virtual step on training data | |||||
lr = self.optimizer.param_groups[0]["lr"] | |||||
momentum = self.optimizer.param_groups[0]["momentum"] | |||||
weight_decay = self.optimizer.param_groups[0]["weight_decay"] | |||||
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay) | |||||
# calculate unrolled loss on validation data | |||||
# keep gradients for model here for compute hessian | |||||
_, loss = self._logits_and_loss(val_X, val_y) | |||||
w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters()) | |||||
w_grads = torch.autograd.grad(loss, w_model + w_ctrl) | |||||
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):] | |||||
# compute hessian and final gradients | |||||
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y) | |||||
with torch.no_grad(): | |||||
for param, d, h in zip(w_ctrl, d_ctrl, hessian): | |||||
# gradient = dalpha - lr * hessian | |||||
param.grad = d - lr * h | |||||
# restore weights | |||||
self._restore_weights(backup_params) | |||||
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay): | |||||
""" | |||||
Compute unrolled weights w` | |||||
""" | |||||
# don't need zero_grad, using autograd to calculate gradients | |||||
_, loss = self._logits_and_loss(X, y) | |||||
gradients = torch.autograd.grad(loss, self.model.parameters()) | |||||
with torch.no_grad(): | |||||
for w, g in zip(self.model.parameters(), gradients): | |||||
m = self.optimizer.state[w].get("momentum_buffer", 0.) | |||||
                # in-place update: rebinding ``w`` would leave the model parameters unchanged
                w -= lr * (momentum * m + g + weight_decay * w)
def _restore_weights(self, backup_params): | |||||
with torch.no_grad(): | |||||
for param, backup in zip(self.model.parameters(), backup_params): | |||||
param.copy_(backup) | |||||
def _compute_hessian(self, backup_params, dw, trn_X, trn_y): | |||||
""" | |||||
dw = dw` { L_val(w`, alpha) } | |||||
w+ = w + eps * dw | |||||
w- = w - eps * dw | |||||
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps) | |||||
eps = 0.01 / ||dw|| | |||||
""" | |||||
self._restore_weights(backup_params) | |||||
norm = torch.cat([w.view(-1) for w in dw]).norm() | |||||
        eps = 0.01 / norm
        if norm < 1E-8:
            logger.warning("In computing hessian, norm is smaller than 1E-8, causing eps to be %.6f.", eps.item())
dalphas = [] | |||||
for e in [eps, -2. * eps]: | |||||
# w+ = w + eps*dw`, w- = w - eps*dw` | |||||
with torch.no_grad(): | |||||
for p, d in zip(self.model.parameters(), dw): | |||||
p += e * d | |||||
_, loss = self._logits_and_loss(trn_X, trn_y) | |||||
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters())) | |||||
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) } | |||||
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)] | |||||
return hessian |
@@ -0,0 +1,56 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import numpy as np | |||||
import torch | |||||
from torchvision import transforms | |||||
from torchvision.datasets import CIFAR10 | |||||
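# Cutout regularization (DeVries & Taylor, 2017): zero out a random square patch of the input.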
class Cutout(object): | |||||
def __init__(self, length): | |||||
self.length = length | |||||
def __call__(self, img): | |||||
h, w = img.size(1), img.size(2) | |||||
mask = np.ones((h, w), np.float32) | |||||
y = np.random.randint(h) | |||||
x = np.random.randint(w) | |||||
y1 = np.clip(y - self.length // 2, 0, h) | |||||
y2 = np.clip(y + self.length // 2, 0, h) | |||||
x1 = np.clip(x - self.length // 2, 0, w) | |||||
x2 = np.clip(x + self.length // 2, 0, w) | |||||
mask[y1: y2, x1: x2] = 0. | |||||
mask = torch.from_numpy(mask) | |||||
mask = mask.expand_as(img) | |||||
img *= mask | |||||
return img | |||||
def get_dataset(cls, cutout_length=0, root=None): | |||||
MEAN = [0.49139968, 0.48215827, 0.44653124] | |||||
STD = [0.24703233, 0.24348505, 0.26158768] | |||||
transf = [ | |||||
transforms.RandomCrop(32, padding=4), | |||||
transforms.RandomHorizontalFlip() | |||||
] | |||||
normalize = [ | |||||
transforms.ToTensor(), | |||||
transforms.Normalize(MEAN, STD) | |||||
] | |||||
cutout = [] | |||||
if cutout_length > 0: | |||||
cutout.append(Cutout(cutout_length)) | |||||
train_transform = transforms.Compose(transf + normalize + cutout) | |||||
valid_transform = transforms.Compose(normalize) | |||||
if cls == "cifar10": | |||||
dataset_train = CIFAR10(root=root, train=True, download=True, transform=train_transform) | |||||
dataset_valid = CIFAR10(root=root, train=False, download=True, transform=valid_transform) | |||||
else: | |||||
raise NotImplementedError | |||||
return dataset_train, dataset_valid |
@@ -0,0 +1,160 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.nn as nn | |||||
import ops | |||||
import pytorch.mutables as mutables | |||||
class AuxiliaryHead(nn.Module): | |||||
""" Auxiliary head in 2/3 place of network to let the gradient flow well """ | |||||
def __init__(self, input_size, C, n_classes): | |||||
""" assuming input size 7x7 or 8x8 """ | |||||
assert input_size in [7, 8] | |||||
super().__init__() | |||||
self.net = nn.Sequential( | |||||
nn.ReLU(inplace=True), | |||||
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out | |||||
nn.Conv2d(C, 128, kernel_size=1, bias=False), | |||||
nn.BatchNorm2d(128), | |||||
nn.ReLU(inplace=True), | |||||
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out | |||||
nn.BatchNorm2d(768), | |||||
nn.ReLU(inplace=True) | |||||
) | |||||
self.linear = nn.Linear(768, n_classes) | |||||
def forward(self, x): | |||||
out = self.net(x) | |||||
out = out.view(out.size(0), -1) # flatten | |||||
logits = self.linear(out) | |||||
return logits | |||||
class Node(nn.Module): | |||||
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): | |||||
super().__init__() | |||||
self.ops = nn.ModuleList() | |||||
choice_keys = [] | |||||
for i in range(num_prev_nodes): | |||||
stride = 2 if i < num_downsample_connect else 1 | |||||
choice_keys.append("{}_p{}".format(node_id, i)) | |||||
self.ops.append( | |||||
mutables.LayerChoice(OrderedDict([ | |||||
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)), | |||||
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)), | |||||
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)), | |||||
("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)), | |||||
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)), | |||||
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)), | |||||
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)) | |||||
]), key=choice_keys[-1])) | |||||
self.drop_path = ops.DropPath() | |||||
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) | |||||
def forward(self, prev_nodes): | |||||
assert len(self.ops) == len(prev_nodes) | |||||
out = [op(node) for op, node in zip(self.ops, prev_nodes)] | |||||
out = [self.drop_path(o) if o is not None else None for o in out] | |||||
return self.input_switch(out) | |||||
class Cell(nn.Module): | |||||
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): | |||||
super().__init__() | |||||
self.reduction = reduction | |||||
self.n_nodes = n_nodes | |||||
# If previous cell is reduction cell, current input size does not match with | |||||
# output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. | |||||
if reduction_p: | |||||
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) | |||||
else: | |||||
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) | |||||
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) | |||||
# generate dag | |||||
self.mutable_ops = nn.ModuleList() | |||||
for depth in range(2, self.n_nodes + 2): | |||||
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth), | |||||
depth, channels, 2 if reduction else 0)) | |||||
def forward(self, s0, s1): | |||||
# s0, s1 are the outputs of previous previous cell and previous cell, respectively. | |||||
tensors = [self.preproc0(s0), self.preproc1(s1)] | |||||
for node in self.mutable_ops: | |||||
cur_tensor = node(tensors) | |||||
tensors.append(cur_tensor) | |||||
output = torch.cat(tensors[2:], dim=1) | |||||
return output | |||||
class CNN(nn.Module): | |||||
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, | |||||
stem_multiplier=3, auxiliary=False): | |||||
super().__init__() | |||||
self.in_channels = in_channels | |||||
self.channels = channels | |||||
self.n_classes = n_classes | |||||
self.n_layers = n_layers | |||||
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1 | |||||
c_cur = stem_multiplier * self.channels | |||||
self.stem = nn.Sequential( | |||||
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), | |||||
nn.BatchNorm2d(c_cur) | |||||
) | |||||
# for the first cell, stem is used for both s0 and s1 | |||||
        # [!] channels_pp and channels_p are output channel sizes, but c_cur is input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels | |||||
self.cells = nn.ModuleList() | |||||
reduction_p, reduction = False, False | |||||
for i in range(n_layers): | |||||
reduction_p, reduction = reduction, False | |||||
# Reduce featuremap size and double channels in 1/3 and 2/3 layer. | |||||
if i in [n_layers // 3, 2 * n_layers // 3]: | |||||
c_cur *= 2 | |||||
reduction = True | |||||
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) | |||||
self.cells.append(cell) | |||||
c_cur_out = c_cur * n_nodes | |||||
channels_pp, channels_p = channels_p, c_cur_out | |||||
if i == self.aux_pos: | |||||
self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes) | |||||
self.gap = nn.AdaptiveAvgPool2d(1) | |||||
self.linear = nn.Linear(channels_p, n_classes) | |||||
def forward(self, x): | |||||
s0 = s1 = self.stem(x) | |||||
aux_logits = None | |||||
for i, cell in enumerate(self.cells): | |||||
s0, s1 = s1, cell(s0, s1) | |||||
if i == self.aux_pos and self.training: | |||||
aux_logits = self.aux_head(s1) | |||||
out = self.gap(s1) | |||||
out = out.view(out.size(0), -1) # flatten | |||||
logits = self.linear(out) | |||||
if aux_logits is not None: | |||||
return logits, aux_logits | |||||
return logits | |||||
def drop_path_prob(self, p): | |||||
for module in self.modules(): | |||||
if isinstance(module, ops.DropPath): | |||||
module.p = p |
@@ -0,0 +1,136 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import torch | |||||
import torch.nn as nn | |||||
class DropPath(nn.Module): | |||||
def __init__(self, p=0.): | |||||
""" | |||||
Drop path with probability. | |||||
Parameters | |||||
---------- | |||||
p : float | |||||
            Probability of a path to be zeroed.
""" | |||||
super().__init__() | |||||
self.p = p | |||||
def forward(self, x): | |||||
if self.training and self.p > 0.: | |||||
keep_prob = 1. - self.p | |||||
# per data point mask | |||||
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob) | |||||
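            # scale by 1 / keep_prob so the expected activation matches evaluation time (inverted-dropout style)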
return x / keep_prob * mask | |||||
return x | |||||
class PoolBN(nn.Module): | |||||
""" | |||||
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`. | |||||
""" | |||||
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): | |||||
super().__init__() | |||||
if pool_type.lower() == 'max': | |||||
self.pool = nn.MaxPool2d(kernel_size, stride, padding) | |||||
elif pool_type.lower() == 'avg': | |||||
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) | |||||
else: | |||||
raise ValueError() | |||||
self.bn = nn.BatchNorm2d(C, affine=affine) | |||||
def forward(self, x): | |||||
out = self.pool(x) | |||||
out = self.bn(out) | |||||
return out | |||||
class StdConv(nn.Module): | |||||
""" | |||||
Standard conv: ReLU - Conv - BN | |||||
""" | |||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | |||||
super().__init__() | |||||
self.net = nn.Sequential( | |||||
nn.ReLU(), | |||||
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), | |||||
nn.BatchNorm2d(C_out, affine=affine) | |||||
) | |||||
def forward(self, x): | |||||
return self.net(x) | |||||
class FacConv(nn.Module): | |||||
""" | |||||
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN | |||||
""" | |||||
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): | |||||
super().__init__() | |||||
self.net = nn.Sequential( | |||||
nn.ReLU(), | |||||
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), | |||||
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False), | |||||
nn.BatchNorm2d(C_out, affine=affine) | |||||
) | |||||
def forward(self, x): | |||||
return self.net(x) | |||||
class DilConv(nn.Module): | |||||
""" | |||||
(Dilated) depthwise separable conv. | |||||
ReLU - (Dilated) depthwise separable - Pointwise - BN. | |||||
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field. | |||||
""" | |||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): | |||||
super().__init__() | |||||
self.net = nn.Sequential( | |||||
nn.ReLU(), | |||||
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, | |||||
bias=False), | |||||
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), | |||||
nn.BatchNorm2d(C_out, affine=affine) | |||||
) | |||||
def forward(self, x): | |||||
return self.net(x) | |||||
class SepConv(nn.Module): | |||||
""" | |||||
Depthwise separable conv. | |||||
DilConv(dilation=1) * 2. | |||||
""" | |||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | |||||
super().__init__() | |||||
self.net = nn.Sequential( | |||||
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), | |||||
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) | |||||
) | |||||
def forward(self, x): | |||||
return self.net(x) | |||||
class FactorizedReduce(nn.Module): | |||||
""" | |||||
Reduce feature map size by factorized pointwise (stride=2). | |||||
""" | |||||
def __init__(self, C_in, C_out, affine=True): | |||||
super().__init__() | |||||
self.relu = nn.ReLU() | |||||
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) | |||||
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) | |||||
self.bn = nn.BatchNorm2d(C_out, affine=affine) | |||||
def forward(self, x): | |||||
x = self.relu(x) | |||||
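        # the second conv samples the input shifted by one pixel, so the two stride-2 paths cover complementary locations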
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) | |||||
out = self.bn(out) | |||||
return out |
@@ -0,0 +1,83 @@ | |||||
# train stage | |||||
`python darts_train.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --search_space_path 'experiment_id/search_space.json' --best_selected_space_path 'experiment_id/best_selected_space.json' --trial_id 0 --layers 8 --model_lr 0.025 --arch_lr 3e-4 --epochs 1 --batch_size 64 --channels 16` | |||||
Note:
here `--epochs 1` is set low just for debugging
# select stage | |||||
`python darts_select.py --best_selected_space_path 'experiment_id/best_selected_space.json' ` | |||||
# retrain stage | |||||
`python darts_retrain.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --best_selected_space_path 'experiment_id/best_selected_space.json' --best_checkpoint_dir 'experiment_id/' --trial_id 0 --batch_size 96 --epochs 1 --lr 0.025 --layers 20 --channels 36` | |||||
# output file | |||||
`result.json` | |||||
``` | |||||
{'type': 'Accuracy', 'result': {'sequence': 0, 'category': 'epoch', 'value': 0.1}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 1, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 2, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 3, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 4, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Cost_time', 'result': {'value': '41.614346981048584 s'}} | |||||
``` | |||||
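
Each line of `result.json` is a Python-dict-style record (single quotes), so it is not strict JSON and `json.loads` will not parse it directly. A minimal sketch for reading it back with `ast.literal_eval` (the path below is an assumption; use your own `--result_path`):

```
import ast

with open('trial_id/result.json') as f:
    records = [ast.literal_eval(line) for line in f if line.strip()]

accuracies = [r['result']['value'] for r in records if r['type'] == 'Accuracy']
print('best accuracy:', max(accuracies) if accuracies else None)
```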
`search_space.json` | |||||
``` | |||||
{ | |||||
"op_list": { | |||||
"_type": "layer_choice", | |||||
"_value": [ | |||||
"maxpool", | |||||
"avgpool", | |||||
"skipconnect", | |||||
"sepconv3x3", | |||||
"sepconv5x5", | |||||
"dilconv3x3", | |||||
"dilconv5x5", | |||||
"none" | |||||
] | |||||
}, | |||||
"search_space": { | |||||
"normal_n2_p0": "op_list", | |||||
"normal_n2_p1": "op_list", | |||||
"normal_n2_switch": { | |||||
"_type": "input_choice", | |||||
"_value": { | |||||
"candidates": [ | |||||
"normal_n2_p0", | |||||
"normal_n2_p1" | |||||
], | |||||
"n_chosen": 2 | |||||
} | |||||
}, | |||||
... | |||||
} | |||||
``` | |||||
`best_selected_space.json` | |||||
``` | |||||
{ | |||||
"normal_n2_p0": "dilconv5x5", | |||||
"normal_n2_p1": "dilconv5x5", | |||||
"normal_n2_switch": [ | |||||
"normal_n2_p0", | |||||
"normal_n2_p1" | |||||
], | |||||
"normal_n3_p0": "sepconv3x3", | |||||
"normal_n3_p1": "dilconv5x5", | |||||
"normal_n3_p2": [], | |||||
"normal_n3_switch": [ | |||||
"normal_n3_p0", | |||||
"normal_n3_p1" | |||||
], | |||||
"normal_n4_p0": [], | |||||
"normal_n4_p1": "dilconv5x5", | |||||
"normal_n4_p2": "sepconv5x5", | |||||
"normal_n4_p3": [], | |||||
"normal_n4_switch": [ | |||||
"normal_n4_p1", | |||||
"normal_n4_p2" | |||||
], | |||||
... | |||||
} | |||||
``` |
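
For reference, the exported `best_selected_space.json` can be inspected with plain `json`; a minimal sketch (the path is illustrative):
```
import json

with open('experiment_id/best_selected_space.json') as f:
    selected = json.load(f)

# "*_p*" keys map to the chosen op name (or [] when that input is unused);
# "*_switch" keys list the chosen predecessor connections.
for key, value in selected.items():
    print(key, '->', value)
```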
@@ -0,0 +1,21 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
def accuracy(output, target, topk=(1,)): | |||||
""" Computes the precision@k for the specified values of k """ | |||||
maxk = max(topk) | |||||
batch_size = target.size(0) | |||||
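    # indices of the top-maxk predictions; transposed below to [maxk, batch_size]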
_, pred = output.topk(maxk, 1, True, True) | |||||
pred = pred.t() | |||||
# one-hot case | |||||
if target.ndimension() > 1: | |||||
target = target.max(1)[1] | |||||
correct = pred.eq(target.view(1, -1).expand_as(pred)) | |||||
res = dict() | |||||
for k in topk: | |||||
correct_k = correct[:k].reshape(-1).float().sum(0) | |||||
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() | |||||
return res |
@@ -0,0 +1,331 @@ | |||||
from nvidia.dali.pipeline import Pipeline | |||||
from nvidia.dali import ops | |||||
from nvidia.dali import types | |||||
from nvidia.dali.plugin.pytorch import DALIClassificationIterator | |||||
import numpy as np | |||||
import torch | |||||
from torch import nn | |||||
class HybridTrainPipeline(Pipeline): | |||||
def __init__(self, batch_size, file_root, num_threads, device_id, num_shards, shard_id): | |||||
super(HybridTrainPipeline, self).__init__(batch_size, num_threads, device_id) | |||||
device_type = {0:"cpu"} | |||||
if num_shards == 0: | |||||
self.input = ops.FileReader(file_root = file_root) | |||||
else: | |||||
self.input = ops.FileReader(file_root = file_root, num_shards = num_shards, shard_id = shard_id) | |||||
        # ##### adjust freely ###################################
self.decode = ops.ImageDecoder(device = device_type.get(num_shards, "mixed"), output_type = types.RGB) | |||||
self.res = ops.RandomResizedCrop(device=device_type.get(num_shards, "gpu"), size = 224) | |||||
self.cmnp = ops.CropMirrorNormalize(device=device_type.get(num_shards, "gpu"), | |||||
dtype = types.FLOAT, # output_dtype=types.FLOAT, | |||||
output_layout=types.NCHW, | |||||
mean=0. ,# if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], | |||||
std=1. )# if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255]) | |||||
# #################################################### | |||||
def define_graph(self, ): | |||||
jpegs, labels = self.input(name="Reader") | |||||
images = self.decode(jpegs) | |||||
images = self.res(images) | |||||
images = self.cmnp(images) | |||||
return images, labels | |||||
class HybridValPipeline(Pipeline): | |||||
def __init__(self, batch_size, file_root, num_threads, device_id, num_shards, shard_id): | |||||
super(HybridValPipeline, self).__init__(batch_size, num_threads, device_id) | |||||
device_type = {0:"cpu"} | |||||
if num_shards == 0: | |||||
self.input = ops.FileReader(file_root = file_root) | |||||
else: | |||||
self.input = ops.FileReader(file_root = file_root, num_shards = num_shards, shard_id = shard_id) | |||||
        # ##### adjust freely ###################################
self.decode = ops.ImageDecoder(device = device_type.get(num_shards, "mixed"), output_type = types.RGB) | |||||
self.res = ops.RandomResizedCrop(device=device_type.get(num_shards, "gpu"), size = 224) | |||||
self.cmnp = ops.CropMirrorNormalize(device=device_type.get(num_shards, "gpu"), | |||||
dtype = types.FLOAT, # output_dtype=types.FLOAT, | |||||
output_layout=types.NCHW, | |||||
mean=0. ,# if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255], | |||||
std=1. )# if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255]) | |||||
# #################################################### | |||||
def define_graph(self, ): | |||||
jpegs, labels = self.input(name="Reader") | |||||
images = self.decode(jpegs) | |||||
images = self.res(images) | |||||
images = self.cmnp(images) | |||||
return images, labels | |||||
class TorchWrapper: | |||||
""" | |||||
将多个pipeline封装为一个iterator | |||||
parameters: | |||||
num_shards : int 显卡并行数 | |||||
data_loader : dali.pipeline.Pipeline类型 经过pipeline处理的数据结果 | |||||
iter_mode : str recursion, iter 指定多个pipeline合并的方式,默认recursion | |||||
""" | |||||
def __init__(self, num_shards, data_loader, iter_mode = "recursion"): | |||||
self.index = 0 | |||||
self.count = 0 | |||||
self.num_shards = num_shards | |||||
self.data_loader = data_loader | |||||
self.iter_mode = iter_mode | |||||
        if self.iter_mode not in {"recursion", "iter"}:
            raise ValueError("iter_mode should be either 'recursion' or 'iter'")
def __iter__(self,): | |||||
return self | |||||
    def __len__(self):
        # return the total number of samples, not the number of batches
        if self.num_shards == 0:
            return self.data_loader.size
        else:
            return len(self.data_loader) * self.data_loader[0].size
    def __next__(self):
        if self.num_shards == 0:
            # CPU-only: a single pipeline
            data = next(self.data_loader)
            return data[0]["data"], data[0]["label"].view(-1).long()
        else:
            # one or more GPUs
if self.iter_mode == "recursion": | |||||
return self._get_next_recursion() | |||||
elif self.iter_mode == "iter": | |||||
return self._get_next_iter(self.data_loader[0]) | |||||
def _get_next_iter(self, data_loader): | |||||
if self.count == data_loader.size: | |||||
self.index+=1 | |||||
data_loader = self.data_loader[self.index] | |||||
self.count+=1 | |||||
data = next(data_loader) | |||||
return data[0]["data"], data[0]["label"].view(-1).long() | |||||
def _get_next_recursion(self, ): | |||||
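        # round-robin over the per-GPU pipelines: the call count modulo the
        # number of shards selects which pipeline serves the next batch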
self.index = self.count%self.num_shards | |||||
self.count+=1 | |||||
data_loader = self.data_loader[self.index] | |||||
data = next(data_loader) | |||||
return data[0]["data"], data[0]["label"].view(-1).long() | |||||
def get_iter_dali_cuda(batch_size=256, train_file_root="", val_file_root="", num_threads=4, device_id=[-1], num_shards=0, shard_id=[-1]): | |||||
""" | |||||
获取可用于pytorch训练的数据迭代器 | |||||
数据的读取和处理部分可以使用多张GPU来完成 | |||||
1、创建dali pipeline | |||||
2、封装为适用于pytorch的数据迭代器 | |||||
3、将多卡的各个pipeline封装在一起 | |||||
4、数据输出在cpu端,在cuda中 | |||||
数据需要保证如下形式: | |||||
images | |||||
|-file_list.txt | |||||
|-images/dog | |||||
|-dog_4.jpg | |||||
|-dog_5.jpg | |||||
|-dog_9.jpg | |||||
|-dog_6.jpg | |||||
|-dog_3.jpg | |||||
|-images/kitten | |||||
|-cat_10.jpg | |||||
|-cat_5.jpg | |||||
|-cat_9.jpg | |||||
|-cat_8.jpg | |||||
|-cat_1.jpg | |||||
parameters: | |||||
batch_size : int 每批数据的量 | |||||
file_root : str 数据的路径 | |||||
num_threads : int 读取数据的CPU线程数 | |||||
device_id : list of int GPU的物理编号 | |||||
shard_id : list of int GPU的虚拟编号 | |||||
num_shard : int | |||||
methods: | |||||
get_train_pipeline(shard_id, device_id) : 创建dali的pipeline,用以读取并处理训练数据 | |||||
get_val_pipeline(shard_id, device_id) : 创建dali的pipeline,用以读取并处理验证数据 | |||||
get_dali_iter_for_torch(piplines, data_num) : 封装成可用于pytorch的数据迭代器 | |||||
get_data_size(pipeline) : 计算每个pipeline实际输出的数据总量,数据总量是文件中的数据量,实际输出是去掉了不满一个批次大小的数据 | |||||
例: | |||||
# 分别从TRAIN_PATH和VAL_PATH读取训练和验证数据,batch_size选择256,启动4个线程来读取数据,用2块GPU处理数据,分别是第0号和第4号GPU | |||||
# 程序默认使用所有显卡,和4线程 | |||||
# 如果使用单张GPU,请设置num_shards = 1, shard_id = [0], device_id保持一个列表形式 | |||||
# 如果不使用GPU,请使用get_iter_dali_cpu() | |||||
train_data_iter, val_data_iter = get_iter_dali(batch_size=256, | |||||
train_file_root=TRAIN_PATH, | |||||
val_file_root=Val_PATH, | |||||
num_threads=4, | |||||
device_id=[0,4], | |||||
num_shards=2, | |||||
shard_id=[0,1]) | |||||
# 在torch中训练 | |||||
torch_model = TorchModel(para) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optimizer = torch.optim.Adam(torch_model.parameters()) | |||||
for epoch in range(epoches): | |||||
for step, x,y in enumerate(train_data_iter): | |||||
# 数据 : x | |||||
# 标签 : y | |||||
x = x.to("cuda:0") | |||||
y = y.to("cuda:0") | |||||
output = my_model(x) | |||||
optimizer.zero_grad() | |||||
loss = criterion(output, y) | |||||
loss.backward() | |||||
optimizer.step() | |||||
... | |||||
... | |||||
""" | |||||
def get_train_pipeline(shard_id, device_id): | |||||
pipeline = HybridTrainPipeline(batch_size = batch_size, | |||||
file_root = train_file_root, | |||||
num_threads = num_threads, | |||||
num_shards = num_shards, | |||||
shard_id = shard_id, | |||||
device_id = device_id) | |||||
return pipeline | |||||
def get_val_pipeline(shard_id, device_id): | |||||
pipeline = HybridValPipeline(batch_size = batch_size, | |||||
file_root = val_file_root, | |||||
num_threads = num_threads, | |||||
num_shards = num_shards, | |||||
shard_id = shard_id, | |||||
device_id = device_id) | |||||
return pipeline | |||||
pipeline_for_train = [get_train_pipeline(shard_id = shard_id_index, device_id = device_id_index) \ | |||||
for shard_id_index, device_id_index in zip(shard_id, device_id)] | |||||
pipeline_for_val = [get_val_pipeline(shard_id = shard_id_index, device_id = device_id_index) \ | |||||
for shard_id_index, device_id_index in zip(shard_id, device_id)] | |||||
[pipeline.build() for pipeline in pipeline_for_train] | |||||
[pipeline.build() for pipeline in pipeline_for_val] | |||||
def get_data_size(pipeline): | |||||
data_num = pipeline.epoch_size()["Reader"] | |||||
batch_size = pipeline.batch_size | |||||
return data_num//batch_size*batch_size | |||||
data_num_train = get_data_size(pipeline_for_train[0]) | |||||
data_num_val = get_data_size(pipeline_for_val[0]) | |||||
def get_dali_iter_for_torch(pipelines, data_num): | |||||
return [DALIClassificationIterator(pipelines=pipeline, | |||||
last_batch_policy="drop",size = data_num) for pipeline in pipelines] | |||||
data_loader_train = get_dali_iter_for_torch(pipeline_for_train, data_num_train) | |||||
data_loader_val = get_dali_iter_for_torch(pipeline_for_val, data_num_val) | |||||
train_data_iter = TorchWrapper(num_shards, data_loader_train) | |||||
val_data_iter = TorchWrapper(num_shards, data_loader_val) | |||||
return train_data_iter, val_data_iter | |||||
def get_iter_dali_cpu(batch_size=256, train_file_root="", val_file_root="", num_threads=4): | |||||
pipeline_train = HybridTrainPipeline(batch_size = batch_size, | |||||
file_root = train_file_root, | |||||
num_threads = num_threads, | |||||
num_shards = 0, | |||||
shard_id = -1, | |||||
device_id = 0) | |||||
    pipeline_val = HybridValPipeline(batch_size = batch_size,
file_root = val_file_root, | |||||
num_threads = num_threads, | |||||
num_shards = 0, | |||||
shard_id = -1, | |||||
device_id = 0) | |||||
pipeline_train.build() | |||||
pipeline_val.build() | |||||
def get_data_size(pipeline): | |||||
data_num = pipeline.epoch_size()["Reader"] | |||||
batch_size = pipeline.batch_size | |||||
return data_num//batch_size*batch_size | |||||
data_num_train = get_data_size(pipeline_train) | |||||
data_num_val = get_data_size(pipeline_val) | |||||
data_loader_train = DALIClassificationIterator(pipelines=pipeline_train, | |||||
last_batch_policy="drop",size = data_num_train) | |||||
data_loader_val = DALIClassificationIterator(pipelines=pipeline_val, | |||||
last_batch_policy="drop",size = data_num_val) | |||||
train_data_iter = TorchWrapper(0,data_loader_train) | |||||
val_data_iter = TorchWrapper(0,data_loader_val) | |||||
return train_data_iter, val_data_iter | |||||
if __name__ == "__main__": | |||||
PATH = "./imagenet" | |||||
TRAIN_PATH = "./imagenet/train" | |||||
VALID_PATH = "./imagenet/val" | |||||
train_data_iter_cuda, val_data_iter_cuda = get_iter_dali_cuda(batch_size=256, | |||||
train_file_root=TRAIN_PATH, | |||||
val_file_root=TRAIN_PATH, | |||||
num_threads=4, | |||||
device_id=[0,4], | |||||
num_shards=2, | |||||
shard_id=[0,1]) | |||||
train_data_iter_cpu, val_data_iter_cpu = get_iter_dali_cpu(batch_size=256, | |||||
train_file_root=TRAIN_PATH, | |||||
val_file_root=TRAIN_PATH, | |||||
num_threads=4) |
@@ -0,0 +1,46 @@ | |||||
# Efficient Neural Architecture Search (ENAS) | |||||
## 1. Requirements | |||||
``` | |||||
torch | |||||
torchvision | |||||
collections | |||||
argparse
pickle | |||||
pytest-shutil | |||||
``` | |||||
## 2. Train
### Stage1: search an architecture | |||||
* macro search | |||||
``` | |||||
python trainer.py --trial_id=0 --search_for macro --best_selected_space_path='./macro_selected_space.json' --result_path='./macro_result.json' | |||||
``` | |||||
* micro search | |||||
``` | |||||
python trainer.py --trial_id=0 --search_for micro --best_selected_space_path='./micro_selected_space.json' --result_path='./micro_result.json' | |||||
``` | |||||
### Stage2: select (deprecated) | |||||
``` | |||||
python selector.py | |||||
``` | |||||
### Stage3: retrain
* macro search | |||||
``` | |||||
python retrainer.py --search_for macro --best_checkpoint_dir='./macro_checkpoint.pth' --best_selected_space_path='./macro_selected_space.json' --result_path='./macro_result.json'
``` | |||||
* micro search | |||||
``` | |||||
python retrainer.py --search_for micro --best_checkpoint_dir='./micro_checkpoint.pth' --best_selected_space_path='./micro_selected_space.json' --result_path='./micro_result.json'
``` |
@@ -0,0 +1,5 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
from .mutator import EnasMutator | |||||
from .trainer import EnasTrainer |
@@ -0,0 +1,28 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
from torchvision import transforms | |||||
from torchvision.datasets import CIFAR10 | |||||
def get_dataset(cls, datadir):
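    # per-channel mean/std of the CIFAR-10 training set, used for input normalization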
MEAN = [0.49139968, 0.48215827, 0.44653124] | |||||
STD = [0.24703233, 0.24348505, 0.26158768] | |||||
transf = [ | |||||
transforms.RandomCrop(32, padding=4), | |||||
transforms.RandomHorizontalFlip() | |||||
] | |||||
normalize = [ | |||||
transforms.ToTensor(), | |||||
transforms.Normalize(MEAN, STD) | |||||
] | |||||
train_transform = transforms.Compose(transf + normalize) | |||||
valid_transform = transforms.Compose(normalize) | |||||
if cls == "cifar10": | |||||
dataset_train = CIFAR10(root=datadir, train=True, download=True, transform=train_transform) | |||||
dataset_valid = CIFAR10(root=datadir, train=False, download=True, transform=valid_transform) | |||||
else: | |||||
raise NotImplementedError | |||||
return dataset_train, dataset_valid |
@@ -0,0 +1,87 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import torch.nn as nn | |||||
import sys | |||||
sys.path.append('../..')
from pytorch import mutables # LayerChoice, InputChoice, MutableScope | |||||
from ops import FactorizedReduce, ConvBranch, PoolBranch | |||||
class ENASLayer(mutables.MutableScope): | |||||
def __init__(self, key, prev_labels, in_filters, out_filters): | |||||
super().__init__(key) | |||||
self.in_filters = in_filters | |||||
self.out_filters = out_filters | |||||
self.mutable = mutables.LayerChoice([ | |||||
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False), | |||||
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True), | |||||
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False), | |||||
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True), | |||||
PoolBranch('avg', in_filters, out_filters, 3, 1, 1), | |||||
PoolBranch('max', in_filters, out_filters, 3, 1, 1) | |||||
]) | |||||
if len(prev_labels) > 0: | |||||
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None) | |||||
else: | |||||
self.skipconnect = None | |||||
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) | |||||
def forward(self, prev_layers): | |||||
out = self.mutable(prev_layers[-1]) | |||||
if self.skipconnect is not None: | |||||
connection = self.skipconnect(prev_layers[:-1]) | |||||
if connection is not None: | |||||
out += connection | |||||
return self.batch_norm(out) | |||||
class GeneralNetwork(nn.Module): | |||||
def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10, | |||||
dropout_rate=0.0): | |||||
super().__init__() | |||||
self.num_layers = num_layers | |||||
self.num_classes = num_classes | |||||
self.out_filters = out_filters | |||||
self.stem = nn.Sequential( | |||||
nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False), | |||||
nn.BatchNorm2d(out_filters) | |||||
) | |||||
pool_distance = self.num_layers // 3 | |||||
self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1] | |||||
self.dropout_rate = dropout_rate | |||||
self.dropout = nn.Dropout(self.dropout_rate) | |||||
self.layers = nn.ModuleList() | |||||
self.pool_layers = nn.ModuleList() | |||||
labels = [] | |||||
for layer_id in range(self.num_layers): | |||||
labels.append("layer_{}".format(layer_id)) | |||||
if layer_id in self.pool_layers_idx: | |||||
self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) | |||||
self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters)) | |||||
self.gap = nn.AdaptiveAvgPool2d(1) | |||||
self.dense = nn.Linear(self.out_filters, self.num_classes) | |||||
def forward(self, x): | |||||
bs = x.size(0) | |||||
cur = self.stem(x) | |||||
layers = [cur] | |||||
for layer_id in range(self.num_layers): | |||||
cur = self.layers[layer_id](layers) | |||||
layers.append(cur) | |||||
if layer_id in self.pool_layers_idx: | |||||
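                # after a pooling point, downsample every stored output so that
                # later skip connections keep matching spatial sizes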
for i, layer in enumerate(layers): | |||||
layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) | |||||
cur = layers[-1] | |||||
cur = self.gap(cur).view(bs, -1) | |||||
cur = self.dropout(cur) | |||||
logits = self.dense(cur) | |||||
return logits |
@@ -0,0 +1,187 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from pytorch import mutables | |||||
from ops import FactorizedReduce, StdConv, SepConvBN, Pool | |||||
class AuxiliaryHead(nn.Module): | |||||
def __init__(self, in_channels, num_classes): | |||||
super().__init__() | |||||
self.in_channels = in_channels | |||||
self.num_classes = num_classes | |||||
self.pooling = nn.Sequential( | |||||
nn.ReLU(), | |||||
nn.AvgPool2d(5, 3, 2) | |||||
) | |||||
self.proj = nn.Sequential( | |||||
StdConv(in_channels, 128), | |||||
StdConv(128, 768) | |||||
) | |||||
self.avg_pool = nn.AdaptiveAvgPool2d(1) | |||||
self.fc = nn.Linear(768, 10, bias=False) | |||||
def forward(self, x): | |||||
bs = x.size(0) | |||||
x = self.pooling(x) | |||||
x = self.proj(x) | |||||
x = self.avg_pool(x).view(bs, -1) | |||||
x = self.fc(x) | |||||
return x | |||||
class Cell(nn.Module): | |||||
def __init__(self, cell_name, prev_labels, channels): | |||||
super().__init__() | |||||
self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True, | |||||
key=cell_name + "_input") | |||||
self.op_choice = mutables.LayerChoice([ | |||||
SepConvBN(channels, channels, 3, 1), | |||||
SepConvBN(channels, channels, 5, 2), | |||||
Pool("avg", 3, 1, 1), | |||||
Pool("max", 3, 1, 1), | |||||
nn.Identity() | |||||
], key=cell_name + "_op") | |||||
def forward(self, prev_layers): | |||||
chosen_input, chosen_mask = self.input_choice(prev_layers) | |||||
cell_out = self.op_choice(chosen_input) | |||||
return cell_out, chosen_mask | |||||
class Node(mutables.MutableScope): | |||||
def __init__(self, node_name, prev_node_names, channels): | |||||
super().__init__(node_name) | |||||
self.cell_x = Cell(node_name + "_x", prev_node_names, channels) | |||||
self.cell_y = Cell(node_name + "_y", prev_node_names, channels) | |||||
def forward(self, prev_layers): | |||||
out_x, mask_x = self.cell_x(prev_layers) | |||||
out_y, mask_y = self.cell_y(prev_layers) | |||||
return out_x + out_y, mask_x | mask_y | |||||
class Calibration(nn.Module): | |||||
def __init__(self, in_channels, out_channels): | |||||
super().__init__() | |||||
self.process = None | |||||
if in_channels != out_channels: | |||||
self.process = StdConv(in_channels, out_channels) | |||||
def forward(self, x): | |||||
if self.process is None: | |||||
return x | |||||
return self.process(x) | |||||
class ReductionLayer(nn.Module): | |||||
def __init__(self, in_channels_pp, in_channels_p, out_channels): | |||||
super().__init__() | |||||
self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False) | |||||
self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False) | |||||
def forward(self, pprev, prev): | |||||
return self.reduce0(pprev), self.reduce1(prev) | |||||
class ENASLayer(nn.Module): | |||||
def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction): | |||||
super().__init__() | |||||
self.preproc0 = Calibration(in_channels_pp, out_channels) | |||||
self.preproc1 = Calibration(in_channels_p, out_channels) | |||||
self.num_nodes = num_nodes | |||||
name_prefix = "reduce" if reduction else "normal" | |||||
self.nodes = nn.ModuleList() | |||||
node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY] | |||||
for i in range(num_nodes): | |||||
node_labels.append("{}_node_{}".format(name_prefix, i)) | |||||
self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels)) | |||||
self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True) | |||||
self.bn = nn.BatchNorm2d(out_channels, affine=False) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
nn.init.kaiming_normal_(self.final_conv_w) | |||||
def forward(self, pprev, prev): | |||||
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev) | |||||
prev_nodes_out = [pprev_, prev_] | |||||
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device) | |||||
for i in range(self.num_nodes): | |||||
node_out, mask = self.nodes[i](prev_nodes_out) | |||||
nodes_used_mask[:mask.size(0)] |= mask.to(node_out.device) | |||||
prev_nodes_out.append(node_out) | |||||
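        # concatenate the "loose ends" (nodes never chosen as an input) along the
        # channel dim; a masked 1x1 conv below turns them into the cell output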
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1) | |||||
unused_nodes = F.relu(unused_nodes) | |||||
conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :] | |||||
conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1) | |||||
out = F.conv2d(unused_nodes, conv_weight) | |||||
return prev, self.bn(out) | |||||
class MicroNetwork(nn.Module): | |||||
def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10, | |||||
dropout_rate=0.0, use_aux_heads=False): | |||||
super().__init__() | |||||
self.num_layers = num_layers | |||||
self.use_aux_heads = use_aux_heads | |||||
self.stem = nn.Sequential( | |||||
nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False), | |||||
nn.BatchNorm2d(out_channels * 3) | |||||
) | |||||
pool_distance = self.num_layers // 3 | |||||
pool_layers = [pool_distance, 2 * pool_distance + 1] | |||||
self.dropout = nn.Dropout(dropout_rate) | |||||
self.layers = nn.ModuleList() | |||||
c_pp = c_p = out_channels * 3 | |||||
c_cur = out_channels | |||||
for layer_id in range(self.num_layers + 2): | |||||
reduction = False | |||||
if layer_id in pool_layers: | |||||
c_cur, reduction = c_p * 2, True | |||||
self.layers.append(ReductionLayer(c_pp, c_p, c_cur)) | |||||
c_pp = c_p = c_cur | |||||
self.layers.append(ENASLayer(num_nodes, c_pp, c_p, c_cur, reduction)) | |||||
if self.use_aux_heads and layer_id == pool_layers[-1] + 1: | |||||
self.layers.append(AuxiliaryHead(c_cur, num_classes)) | |||||
c_pp, c_p = c_p, c_cur | |||||
self.gap = nn.AdaptiveAvgPool2d(1) | |||||
self.dense = nn.Linear(c_cur, num_classes) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
for m in self.modules(): | |||||
if isinstance(m, nn.Conv2d): | |||||
nn.init.kaiming_normal_(m.weight) | |||||
def forward(self, x): | |||||
bs = x.size(0) | |||||
prev = cur = self.stem(x) | |||||
aux_logits = None | |||||
for layer in self.layers: | |||||
if isinstance(layer, AuxiliaryHead): | |||||
if self.training: | |||||
aux_logits = layer(cur) | |||||
else: | |||||
prev, cur = layer(prev, cur) | |||||
cur = self.gap(F.relu(cur)).view(bs, -1) | |||||
cur = self.dropout(cur) | |||||
logits = self.dense(cur) | |||||
if aux_logits is not None: | |||||
return logits, aux_logits | |||||
return logits |
@@ -0,0 +1,197 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from pytorch.mutator import Mutator | |||||
from pytorch.mutables import LayerChoice, InputChoice, MutableScope | |||||
class StackedLSTMCell(nn.Module): | |||||
def __init__(self, layers, size, bias): | |||||
super().__init__() | |||||
self.lstm_num_layers = layers | |||||
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias) | |||||
for _ in range(self.lstm_num_layers)]) | |||||
    def forward(self, inputs, hidden):
        prev_c, prev_h = hidden
        next_c, next_h = [], []
        for i, m in enumerate(self.lstm_modules):
            # nn.LSTMCell takes and returns (h, c), so pass the states in that order
            curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_c, next_h
class EnasMutator(Mutator): | |||||
""" | |||||
A mutator that mutates the graph with RL. | |||||
Parameters | |||||
---------- | |||||
model : nn.Module | |||||
PyTorch model. | |||||
lstm_size : int | |||||
Controller LSTM hidden units. | |||||
lstm_num_layers : int | |||||
Number of layers for stacked LSTM. | |||||
tanh_constant : float | |||||
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``. | |||||
cell_exit_extra_step : bool | |||||
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state | |||||
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper. | |||||
skip_target : float | |||||
Target probability that skipconnect will appear. | |||||
temperature : float | |||||
Temperature constant that divides the logits. | |||||
branch_bias : float | |||||
Manual bias applied to make some operations more likely to be chosen. | |||||
Currently this is implemented with a hardcoded match rule that aligns with original repo. | |||||
If a mutable has a ``reduce`` in its key, all its op choices | |||||
that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others | |||||
receive a bias of ``-self.branch_bias``. | |||||
entropy_reduction : str | |||||
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced. | |||||
""" | |||||
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, | |||||
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"): | |||||
super().__init__(model) | |||||
self.lstm_size = lstm_size | |||||
self.lstm_num_layers = lstm_num_layers | |||||
self.tanh_constant = tanh_constant | |||||
self.temperature = temperature | |||||
self.cell_exit_extra_step = cell_exit_extra_step | |||||
self.skip_target = skip_target | |||||
self.branch_bias = branch_bias | |||||
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) | |||||
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) | |||||
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) | |||||
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False) | |||||
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1) | |||||
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable | |||||
assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean." | |||||
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean | |||||
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") | |||||
self.bias_dict = nn.ParameterDict() | |||||
self.max_layer_choice = 0 | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
if self.max_layer_choice == 0: | |||||
self.max_layer_choice = len(mutable) | |||||
assert self.max_layer_choice == len(mutable), \ | |||||
"ENAS mutator requires all layer choice have the same number of candidates." | |||||
# We are judging by keys and module types to add biases to layer choices. Needs refactor. | |||||
if "reduce" in mutable.key: | |||||
def is_conv(choice): | |||||
return "conv" in str(type(choice)).lower() | |||||
bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable | |||||
for choice in mutable]) | |||||
self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False) | |||||
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) | |||||
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False) | |||||
def sample_search(self): | |||||
self._initialize() | |||||
self._sample(self.mutables) | |||||
return self._choices | |||||
def sample_final(self): | |||||
return self.sample_search() | |||||
def _sample(self, tree): | |||||
mutable = tree.mutable | |||||
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices: | |||||
self._choices[mutable.key] = self._sample_layer_choice(mutable) | |||||
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices: | |||||
self._choices[mutable.key] = self._sample_input_choice(mutable) | |||||
for child in tree.children: | |||||
self._sample(child) | |||||
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid: | |||||
if self.cell_exit_extra_step: | |||||
self._lstm_next_step() | |||||
self._mark_anchor(mutable.key) | |||||
def _initialize(self): | |||||
self._choices = dict() | |||||
self._anchors_hid = dict() | |||||
self._inputs = self.g_emb.data | |||||
self._c = [torch.zeros((1, self.lstm_size), | |||||
dtype=self._inputs.dtype, | |||||
device=self._inputs.device) for _ in range(self.lstm_num_layers)] | |||||
self._h = [torch.zeros((1, self.lstm_size), | |||||
dtype=self._inputs.dtype, | |||||
device=self._inputs.device) for _ in range(self.lstm_num_layers)] | |||||
self.sample_log_prob = 0 | |||||
self.sample_entropy = 0 | |||||
self.sample_skip_penalty = 0 | |||||
def _lstm_next_step(self): | |||||
self._c, self._h = self.lstm(self._inputs, (self._c, self._h)) | |||||
def _mark_anchor(self, key): | |||||
self._anchors_hid[key] = self._h[-1] | |||||
def _sample_layer_choice(self, mutable): | |||||
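        # one controller step: advance the LSTM, project the hidden state to op
        # logits, then sample an op and accumulate log-prob and entropy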
self._lstm_next_step() | |||||
logit = self.soft(self._h[-1]) | |||||
if self.temperature is not None: | |||||
logit /= self.temperature | |||||
if self.tanh_constant is not None: | |||||
logit = self.tanh_constant * torch.tanh(logit) | |||||
if mutable.key in self.bias_dict: | |||||
logit += self.bias_dict[mutable.key] | |||||
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) | |||||
log_prob = self.cross_entropy_loss(logit, branch_id) | |||||
self.sample_log_prob += self.entropy_reduction(log_prob) | |||||
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type | |||||
self.sample_entropy += self.entropy_reduction(entropy) | |||||
self._inputs = self.embedding(branch_id) | |||||
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) | |||||
def _sample_input_choice(self, mutable): | |||||
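        # attention over the anchors (hidden states) of the candidate inputs;
        # with n_chosen=None an independent skip decision is sampled per candidate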
query, anchors = [], [] | |||||
for label in mutable.choose_from: | |||||
if label not in self._anchors_hid: | |||||
self._lstm_next_step() | |||||
                self._mark_anchor(label)  # anchor not seen before (e.g. NO_KEY): record the current hidden state
query.append(self.attn_anchor(self._anchors_hid[label])) | |||||
anchors.append(self._anchors_hid[label]) | |||||
query = torch.cat(query, 0) | |||||
query = torch.tanh(query + self.attn_query(self._h[-1])) | |||||
query = self.v_attn(query) | |||||
if self.temperature is not None: | |||||
query /= self.temperature | |||||
if self.tanh_constant is not None: | |||||
query = self.tanh_constant * torch.tanh(query) | |||||
if mutable.n_chosen is None: | |||||
logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type | |||||
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) | |||||
skip_prob = torch.sigmoid(logit) | |||||
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets)) | |||||
self.sample_skip_penalty += kl | |||||
log_prob = self.cross_entropy_loss(logit, skip) | |||||
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0) | |||||
else: | |||||
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS." | |||||
logit = query.view(1, -1) | |||||
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) | |||||
skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1) | |||||
log_prob = self.cross_entropy_loss(logit, index) | |||||
self._inputs = anchors[index.item()] | |||||
self.sample_log_prob += self.entropy_reduction(log_prob) | |||||
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type | |||||
self.sample_entropy += self.entropy_reduction(entropy) | |||||
return skip.bool() |
@@ -0,0 +1,129 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import torch | |||||
import torch.nn as nn | |||||
class StdConv(nn.Module): | |||||
def __init__(self, C_in, C_out): | |||||
super(StdConv, self).__init__() | |||||
self.conv = nn.Sequential( | |||||
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), | |||||
nn.BatchNorm2d(C_out, affine=False), | |||||
nn.ReLU() | |||||
) | |||||
def forward(self, x): | |||||
return self.conv(x) | |||||
def __str__(self): | |||||
return 'StdConv' | |||||
class PoolBranch(nn.Module): | |||||
def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False): | |||||
super().__init__() | |||||
self.kernel_size = kernel_size | |||||
self.pool_type = pool_type | |||||
self.preproc = StdConv(C_in, C_out) | |||||
self.pool = Pool(pool_type, kernel_size, stride, padding) | |||||
self.bn = nn.BatchNorm2d(C_out, affine=affine) | |||||
def forward(self, x): | |||||
out = self.preproc(x) | |||||
out = self.pool(out) | |||||
out = self.bn(out) | |||||
return out | |||||
def __str__(self): | |||||
return '{}PoolBranch_{}'.format(self.pool_type, self.kernel_size) | |||||
class SeparableConv(nn.Module): | |||||
def __init__(self, C_in, C_out, kernel_size, stride, padding): | |||||
self.kernel_size = kernel_size | |||||
super(SeparableConv, self).__init__() | |||||
self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride, | |||||
groups=C_in, bias=False) | |||||
self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False) | |||||
def forward(self, x): | |||||
out = self.depthwise(x) | |||||
out = self.pointwise(out) | |||||
return out | |||||
def __str__(self): | |||||
return 'SeparableConv_{}'.format(self.kernel_size) | |||||
class ConvBranch(nn.Module): | |||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, separable): | |||||
super(ConvBranch, self).__init__() | |||||
self.kernel_size = kernel_size | |||||
self.preproc = StdConv(C_in, C_out) | |||||
if separable: | |||||
self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding) | |||||
else: | |||||
self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding) | |||||
self.postproc = nn.Sequential( | |||||
nn.BatchNorm2d(C_out, affine=False), | |||||
nn.ReLU() | |||||
) | |||||
def forward(self, x): | |||||
out = self.preproc(x) | |||||
out = self.conv(out) | |||||
out = self.postproc(out) | |||||
return out | |||||
def __str__(self): | |||||
return 'ConvBranch_{}'.format(self.kernel_size) | |||||
class FactorizedReduce(nn.Module): | |||||
def __init__(self, C_in, C_out, affine=False): | |||||
super().__init__() | |||||
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) | |||||
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) | |||||
self.bn = nn.BatchNorm2d(C_out, affine=affine) | |||||
def forward(self, x): | |||||
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) | |||||
out = self.bn(out) | |||||
return out | |||||
def __str__(self): | |||||
return 'FactorizedReduce' | |||||
class Pool(nn.Module): | |||||
def __init__(self, pool_type, kernel_size, stride, padding): | |||||
super().__init__() | |||||
self.kernel_size = kernel_size | |||||
self.pool_type = pool_type | |||||
if pool_type.lower() == 'max': | |||||
self.pool = nn.MaxPool2d(kernel_size, stride, padding) | |||||
elif pool_type.lower() == 'avg': | |||||
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) | |||||
else: | |||||
raise ValueError() | |||||
def forward(self, x): | |||||
return self.pool(x) | |||||
def __str__(self): | |||||
return '{}Pool_{}'.format(self.pool_type, self.kernel_size) | |||||
class SepConvBN(nn.Module): | |||||
def __init__(self, C_in, C_out, kernel_size, padding): | |||||
super().__init__() | |||||
self.kernel_size = kernel_size | |||||
self.relu = nn.ReLU() | |||||
self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding) | |||||
self.bn = nn.BatchNorm2d(C_out, affine=True) | |||||
def forward(self, x): | |||||
x = self.relu(x) | |||||
x = self.conv(x) | |||||
x = self.bn(x) | |||||
return x | |||||
def __str__(self): | |||||
return 'SepConvBN_{}'.format(self.kernel_size) |
@@ -0,0 +1,490 @@ | |||||
import sys | |||||
sys.path.append('../..')
import os | |||||
import logging | |||||
import pickle | |||||
import shutil | |||||
import random | |||||
import math | |||||
import time | |||||
import datetime | |||||
import argparse | |||||
import distutils.util | |||||
import numpy as np | |||||
import json | |||||
import torch | |||||
from torch import nn | |||||
from torch import optim | |||||
from torch.utils.data import DataLoader | |||||
import torch.nn.functional as Func | |||||
from macro import GeneralNetwork | |||||
from micro import MicroNetwork | |||||
import datasets | |||||
from utils import accuracy, reward_accuracy | |||||
from pytorch.fixed import apply_fixed_architecture | |||||
from pytorch.utils import AverageMeterGroup, to_device, save_best_checkpoint | |||||
logger = logging.getLogger("enas-retrain") | |||||
# TODO: | |||||
def set_random_seed(seed): | |||||
logger.info("set random seed for data reading: {}".format(seed)) | |||||
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
if FLAGS.is_cuda: | |||||
torch.cuda.manual_seed_all(seed) | |||||
torch.backends.cudnn.deterministic = True | |||||
# TODO: parse args
def parse_args(): | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument( | |||||
"--data_dir", | |||||
type=str, | |||||
default="./data", | |||||
help="Directory containing the dataset and embedding file. (default: %(default)s)") | |||||
parser.add_argument("--search_space_path", type=str, | |||||
default='./search_space.json', help="search_space directory") | |||||
parser.add_argument( | |||||
"--selected_space_path", | |||||
type=str, | |||||
default="./selected_space.json", | |||||
# required=True, | |||||
help="Architecture json file. (default: %(default)s)") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='./result.json', help="res directory") | |||||
parser.add_argument('--trial_id', type=int, default=0, metavar='N', | |||||
help='trial_id,start from 0') | |||||
parser.add_argument( | |||||
"--output_dir", | |||||
type=str, | |||||
default="./output", | |||||
help="The output directory. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--best_checkpoint_dir", | |||||
type=str, | |||||
default="best_checkpoint", | |||||
help="Path for saved checkpoints. (default: %(default)s)") | |||||
parser.add_argument("--search_for", | |||||
choices=["macro", "micro"], | |||||
default="micro") | |||||
parser.add_argument( | |||||
"--batch_size", | |||||
type=int, | |||||
default=128, | |||||
help="Number of samples each batch for training. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--eval_batch_size", | |||||
type=int, | |||||
default=128, | |||||
help="Number of samples each batch for evaluation. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--class_num", | |||||
type=int, | |||||
default=10, | |||||
help="The number of categories. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--epochs", | |||||
type=int, | |||||
default=10, | |||||
help="The number of training epochs. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--child_lr", | |||||
type=float, | |||||
default=0.02, | |||||
help="The initial learning rate. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--is_cuda", | |||||
type=distutils.util.strtobool, | |||||
default=True, | |||||
help="Specify the device type. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--load_checkpoint", | |||||
type=distutils.util.strtobool, | |||||
default=False, | |||||
help="Whether to load checkpoint. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--log_every", | |||||
type=int, | |||||
default=50, | |||||
help="How many steps to log. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--eval_every_epochs", | |||||
type=int, | |||||
default=1, | |||||
help="How many epochs to eval. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--child_grad_bound", | |||||
type=float, | |||||
default=5.0, | |||||
help="The threshold for gradient clipping. (default: %(default)s)") # | |||||
parser.add_argument( | |||||
"--child_lr_decay_scheme", | |||||
type=str, | |||||
default="cosine", | |||||
help="Learning rate annealing strategy, only 'cosine' supported. (default: %(default)s)") #todo: remove | |||||
parser.add_argument( | |||||
"--child_lr_T_0", | |||||
type=int, | |||||
default=10, | |||||
help="The length of one cycle. (default: %(default)s)") # todo: use for | |||||
parser.add_argument( | |||||
"--child_lr_T_mul", | |||||
type=int, | |||||
default=2, | |||||
help="The multiplication factor per cycle. (default: %(default)s)") # todo: use for | |||||
parser.add_argument( | |||||
"--child_l2_reg", | |||||
type=float, | |||||
default=3e-6, | |||||
help="Weight decay factor. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--child_lr_max", | |||||
type=float, | |||||
default=0.002, | |||||
help="The max learning rate. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--child_lr_min", | |||||
type=float, | |||||
default=0.001, | |||||
help="The min learning rate. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--multi_path", | |||||
type=distutils.util.strtobool, | |||||
default=False, | |||||
help="Search for multiple path in the architecture. (default: %(default)s)") # todo: use for | |||||
parser.add_argument( | |||||
"--is_mask", | |||||
type=distutils.util.strtobool, | |||||
default=True, | |||||
help="Apply mask. (default: %(default)s)") | |||||
global FLAGS | |||||
FLAGS = parser.parse_args() | |||||
def print_user_flags(FLAGS, line_limit=80): | |||||
log_strings = "\n" + "-" * line_limit + "\n" | |||||
for flag_name in sorted(vars(FLAGS)): | |||||
value = "{}".format(getattr(FLAGS, flag_name)) | |||||
log_string = flag_name | |||||
log_string += "." * (line_limit - len(flag_name) - len(value)) | |||||
log_string += value | |||||
log_strings = log_strings + log_string | |||||
log_strings = log_strings + "\n" | |||||
log_strings += "-" * line_limit | |||||
logger.info(log_strings) | |||||
def eval_once(child_model, device, eval_set, criterion, valid_dataloader=None, test_dataloader=None): | |||||
if eval_set == "test": | |||||
assert test_dataloader is not None | |||||
dataloader = test_dataloader | |||||
elif eval_set == "valid": | |||||
assert valid_dataloader is not None | |||||
dataloader = valid_dataloader | |||||
else: | |||||
raise NotImplementedError("Unknown eval_set '{}'".format(eval_set)) | |||||
tot_acc = 0 | |||||
tot = 0 | |||||
losses = [] | |||||
with torch.no_grad(): # save memory | |||||
for batch in dataloader: | |||||
x, y = batch | |||||
x, y = to_device(x, device), to_device(y, device) | |||||
logits = child_model(x) | |||||
if isinstance(logits, tuple): | |||||
logits, aux_logits = logits | |||||
aux_loss = criterion(aux_logits, y) | |||||
else: | |||||
aux_loss = 0. | |||||
loss = criterion(logits, y) | |||||
loss = loss + aux_weight * aux_loss | |||||
# loss = loss.mean() | |||||
preds = logits.argmax(dim=1).long() | |||||
acc = torch.eq(preds, y.long()).long().sum().item() | |||||
losses.append(loss) | |||||
tot_acc += acc | |||||
tot += len(y) | |||||
    losses = torch.stack(losses)
loss = losses.mean() | |||||
if tot > 0: | |||||
final_acc = float(tot_acc) / tot | |||||
else: | |||||
final_acc = 0 | |||||
logger.info("Error in calculating final_acc") | |||||
return final_acc, loss | |||||
# TODO: learning rate scheduler | |||||
def update_lr( | |||||
optimizer, | |||||
epoch, | |||||
l2_reg=1e-4, | |||||
lr_warmup_val=None, | |||||
lr_init=0.1, | |||||
lr_decay_scheme="cosine", | |||||
lr_max=0.002, | |||||
lr_min=0.000000001, | |||||
lr_T_0=4, | |||||
lr_T_mul=1, | |||||
sync_replicas=False, | |||||
num_aggregate=None, | |||||
num_replicas=None): | |||||
if lr_decay_scheme == "cosine": | |||||
assert lr_max is not None, "Need lr_max to use lr_cosine" | |||||
assert lr_min is not None, "Need lr_min to use lr_cosine" | |||||
assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine" | |||||
assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine" | |||||
T_i = lr_T_0 | |||||
t_epoch = epoch | |||||
last_reset = 0 | |||||
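        # cosine annealing with warm restarts: walk through cycles of length
        # T_0, T_0*T_mul, T_0*T_mul^2, ... to find the cycle containing `epoch`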
while True: | |||||
t_epoch -= T_i | |||||
if t_epoch < 0: | |||||
break | |||||
last_reset += T_i | |||||
T_i *= lr_T_mul | |||||
T_curr = epoch - last_reset | |||||
def _update(): | |||||
            rate = T_curr / T_i * math.pi
lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(rate)) | |||||
return lr | |||||
learning_rate = _update() | |||||
else: | |||||
raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme)) | |||||
#update lr in optimizer | |||||
for params_group in optimizer.param_groups: | |||||
params_group['lr'] = learning_rate | |||||
return learning_rate | |||||
def train(device, output_dir='./output'): | |||||
workers = 4 | |||||
data = 'cifar10' | |||||
data_dir = FLAGS.data_dir | |||||
output_dir = FLAGS.output_dir | |||||
checkpoint_dir = FLAGS.best_checkpoint_dir | |||||
batch_size = FLAGS.batch_size | |||||
eval_batch_size = FLAGS.eval_batch_size | |||||
class_num = FLAGS.class_num | |||||
epochs = FLAGS.epochs | |||||
child_lr = FLAGS.child_lr | |||||
is_cuda = FLAGS.is_cuda | |||||
load_checkpoint = FLAGS.load_checkpoint | |||||
log_every = FLAGS.log_every | |||||
eval_every_epochs = FLAGS.eval_every_epochs | |||||
child_grad_bound = FLAGS.child_grad_bound | |||||
child_l2_reg = FLAGS.child_l2_reg | |||||
logger.info("Build dataloader") | |||||
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", data_dir)
n_train = len(dataset_train) | |||||
split = n_train // 10 | |||||
indices = list(range(n_train)) | |||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) | |||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) | |||||
train_dataloader = torch.utils.data.DataLoader(dataset_train, | |||||
batch_size=batch_size, | |||||
sampler=train_sampler, | |||||
num_workers=workers) | |||||
valid_dataloader = torch.utils.data.DataLoader(dataset_train, | |||||
batch_size=batch_size, | |||||
sampler=valid_sampler, | |||||
num_workers=workers) | |||||
test_dataloader = torch.utils.data.DataLoader(dataset_valid, | |||||
batch_size=batch_size, | |||||
num_workers=workers) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optimizer = torch.optim.SGD(child_model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4, nesterov=True) | |||||
# optimizer = optim.Adam(child_model.parameters(), eps=1e-3, weight_decay=FLAGS.child_l2_reg) | |||||
# TODO | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.001) | |||||
# move model to CPU/GPU device | |||||
child_model.to(device) | |||||
criterion.to(device) | |||||
logger.info('Start training') | |||||
start_time = time.time() | |||||
step = 0 | |||||
# save path | |||||
if not os.path.exists(output_dir): | |||||
os.mkdir(output_dir) | |||||
# model_save_path = os.path.join(output_dir, "model.pth") | |||||
# best_model_save_path = os.path.join(output_dir, "best_model.pth") | |||||
best_acc = 0 | |||||
start_epoch = 0 | |||||
# TODO: load checkpoints | |||||
# train | |||||
for epoch in range(start_epoch, epochs): | |||||
lr = update_lr(optimizer, | |||||
epoch, | |||||
l2_reg= 1e-4, | |||||
lr_warmup_val=None, | |||||
lr_init=FLAGS.child_lr, | |||||
lr_decay_scheme=FLAGS.child_lr_decay_scheme, | |||||
lr_max=0.05, | |||||
lr_min=0.001, | |||||
lr_T_0=10, | |||||
lr_T_mul=2) | |||||
child_model.train() | |||||
for batch in train_dataloader: | |||||
step += 1 | |||||
x, y = batch | |||||
x, y = to_device(x, device), to_device(y, device) | |||||
logits = child_model(x) | |||||
if isinstance(logits, tuple): | |||||
logits, aux_logits = logits | |||||
aux_loss = criterion(aux_logits, y) | |||||
else: | |||||
aux_loss = 0. | |||||
acc = accuracy(logits, y) | |||||
loss = criterion(logits, y) | |||||
loss = loss + aux_weight * aux_loss | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
            grad_norm = nn.utils.clip_grad_norm_(child_model.parameters(), child_grad_bound)  # clip grad
optimizer.step() | |||||
if step % log_every == 0: | |||||
curr_time = time.time() | |||||
log_string = "" | |||||
log_string += "epoch={:<6d}".format(epoch) | |||||
log_string += "ch_step={:<6d}".format(step) | |||||
log_string += " loss={:<8.6f}".format(loss) | |||||
log_string += " lr={:<8.4f}".format(lr) | |||||
log_string += " |g|={:<8.4f}".format(grad_norm) | |||||
log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0]) | |||||
log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60) | |||||
logger.info(log_string) | |||||
epoch += 1 | |||||
save_state = { | |||||
'step': step, | |||||
'epoch': epoch, | |||||
'child_model_state_dict': child_model.state_dict(), | |||||
'optimizer_state_dict': optimizer.state_dict()} | |||||
# print(' Epoch {:<3d} loss: {:<.2f} '.format(epoch, loss)) | |||||
# torch.save(save_state, model_save_path) | |||||
child_model.eval() | |||||
logger.info("Epoch {}: Eval".format(epoch)) | |||||
eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader) | |||||
logger.info( | |||||
"ch_step={} {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format(step, "test", eval_acc, "test", eval_loss)) | |||||
if eval_acc > best_acc: | |||||
best_acc = eval_acc | |||||
logger.info("Save best model") | |||||
# save_state = { | |||||
# 'step': step, | |||||
# 'epoch': epoch, | |||||
# 'child_model_state_dict': child_model.state_dict(), | |||||
# 'optimizer_state_dict': optimizer.state_dict()} | |||||
# torch.save(save_state, best_model_save_path) | |||||
save_best_checkpoint(checkpoint_dir, child_model, optimizer, epoch) | |||||
result['accuracy'].append('Epoch {} acc: {:<6.4f}'.format(epoch, eval_acc,)) | |||||
acc_l.append(eval_acc) | |||||
print(result['accuracy'][-1]) | |||||
print('max acc %.4f at epoch: %i'%(max(acc_l), np.argmax(np.array(acc_l)))) | |||||
print('Time cost: %.4f hours'%( float(time.time() - start_time) /3600. )) | |||||
return result | |||||
# macro = True | |||||
parse_args() | |||||
child_fixed_arc = FLAGS.selected_space_path  # e.g. './macro_selected_space.json'
search_for = FLAGS.search_for | |||||
# set random seeds
torch.manual_seed(FLAGS.trial_id) | |||||
torch.cuda.manual_seed_all(FLAGS.trial_id) | |||||
np.random.seed(FLAGS.trial_id) | |||||
random.seed(FLAGS.trial_id) | |||||
aux_weight = 0.4 | |||||
result = {'accuracy':[]} | |||||
acc_l = [] | |||||
# decode human readable search space to model | |||||
def convert_selected_space_format(): | |||||
# with open('./macro_selected_space.json') as js: | |||||
with open(child_fixed_arc) as js: | |||||
selected_space = json.load(js) | |||||
ops = selected_space['op_list'] | |||||
selected_space.pop('op_list') | |||||
new_selected_space = {} | |||||
for key, value in selected_space.items(): | |||||
# for macro | |||||
if FLAGS.search_for == 'macro': | |||||
new_key = key.split('_')[-1] | |||||
# for micro | |||||
elif FLAGS.search_for == 'micro': | |||||
new_key = key | |||||
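        # one-element lists naming an op are mapped to their index in op_list;
        # strings, empty lists, and multi-element lists are kept as-is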
        if len(value) > 1 or len(value) == 0:
new_value = value | |||||
elif len(value) > 0 and value[0] in ops: | |||||
new_value = ops.index(value[0]) | |||||
else: | |||||
new_value = value[0] | |||||
new_selected_space[new_key] = new_value | |||||
return new_selected_space | |||||
fixed_arc = convert_selected_space_format() | |||||
# TODO : macro search or micro search | |||||
if FLAGS.search_for == 'macro': | |||||
child_model = GeneralNetwork() | |||||
elif FLAGS.search_for == 'micro': | |||||
child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) | |||||
apply_fixed_architecture(child_model,fixed_arc) | |||||
def dump_global_result(res_path,global_result, sort_keys = False): | |||||
with open(res_path, "w") as ss_file: | |||||
json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2) | |||||
def main(): | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = '4' | |||||
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |||||
device = torch.device("cuda" if FLAGS.is_cuda else "cpu") | |||||
train(device) | |||||
dump_global_result('result_retrain.json', result['accuracy']) | |||||
if __name__ == "__main__": | |||||
main() | |||||
@@ -0,0 +1,495 @@ | |||||
import sys | |||||
from utils import accuracy | |||||
import torch | |||||
from torch import nn | |||||
from torch.utils.data import DataLoader | |||||
import datasets | |||||
import time | |||||
import logging | |||||
import os | |||||
import argparse | |||||
import distutils.util | |||||
import numpy as np | |||||
import json | |||||
import random | |||||
sys.path.append('../..')
# import custom packages | |||||
from macro import GeneralNetwork | |||||
from micro import MicroNetwork | |||||
from pytorch.fixed import apply_fixed_architecture | |||||
from pytorch.retrainer import Retrainer | |||||
from pytorch.utils import AverageMeterGroup, to_device, save_best_checkpoint, mkdirs | |||||
class EnasRetrainer(Retrainer): | |||||
""" | |||||
ENAS retrainer. | |||||
Parameters | |||||
---------- | |||||
model : nn.Module | |||||
PyTorch model to be trained. | |||||
    data_dir : str
        The path of the dataset.
    best_checkpoint_dir : str
        The directory for saving the best checkpoint.
batch_size : int | |||||
Batch size. | |||||
    eval_batch_size : int
        Batch size for evaluation.
num_epochs : int | |||||
Number of epochs planned for training. | |||||
lr : float | |||||
Learning rate. | |||||
is_cuda: Boolean | |||||
Whether to use GPU for training. | |||||
log_every : int | |||||
Step count per logging. | |||||
    child_grad_bound : float
        Gradient clipping threshold.
    child_l2_reg : float
        L2 regularization factor.
eval_every_epochs: int | |||||
Evaluate every epochs. | |||||
    logger : logging.Logger
        Logger for progress messages.
workers : int | |||||
Workers for data loading. | |||||
device : torch.device | |||||
``torch.device("cpu")`` or ``torch.device("cuda")``. | |||||
aux_weight : float | |||||
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss. | |||||
""" | |||||
    def __init__(self, model, data_dir='./data', best_checkpoint_dir='./best_checkpoint',
                 batch_size=1024, eval_batch_size=1024, num_epochs=2, lr=0.02, is_cuda='True',
                 log_every=40, child_grad_bound=0.5, child_l2_reg=3e-6, eval_every_epochs=2,
                 logger=logging.getLogger("enas-retrain"), result_path='./'):
self.aux_weight = 0.4 | |||||
self.device = torch.device("cuda:0" ) | |||||
self.workers = 4 | |||||
self.child_model = model | |||||
self.data_dir = data_dir | |||||
self.best_checkpoint_dir = best_checkpoint_dir | |||||
self.batch_size = batch_size | |||||
self.eval_batch_size = eval_batch_size | |||||
self.num_epochs = num_epochs | |||||
self.lr = lr | |||||
self.is_cuda = is_cuda | |||||
self.log_every = log_every | |||||
self.child_grad_bound = child_grad_bound | |||||
self.child_l2_reg = child_l2_reg | |||||
self.eval_every_epochs = eval_every_epochs | |||||
self.logger = logger | |||||
self.optimizer = torch.optim.SGD(self.child_model.parameters(), self.lr, momentum=0.9, weight_decay=1.0E-4, nesterov=True) | |||||
self.criterion = nn.CrossEntropyLoss() | |||||
self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.num_epochs, eta_min=0.001) | |||||
# load dataset | |||||
self.init_dataloader() | |||||
self.child_model.to(self.device) | |||||
self.result_path = result_path | |||||
with open(self.result_path, "w") as file: | |||||
file.write('') | |||||
def train(self): | |||||
""" | |||||
Train ``num_epochs``. | |||||
Trigger callbacks at the start and the end of each epoch. | |||||
Parameters | |||||
---------- | |||||
validate : bool | |||||
If ``true``, will do validation every epoch. | |||||
""" | |||||
self.logger.info('** Start training **') | |||||
self.start_time = time.time() | |||||
for epoch in range(self.num_epochs): | |||||
self.train_one_epoch(epoch) | |||||
self.child_model.eval() | |||||
            # if epoch % self.eval_every_epochs == 0:
self.logger.info("Epoch {}: Eval".format(epoch)) | |||||
self.validate_one_epoch(epoch) | |||||
self.lr_scheduler.step() | |||||
        self.logger.info("** Save best model **")
save_best_checkpoint(self.best_checkpoint_dir, self.child_model, self.optimizer, epoch) | |||||
def validate(self): | |||||
""" | |||||
Do one validation. Validate one epoch. | |||||
""" | |||||
pass | |||||
def export(self, file): | |||||
""" | |||||
dump the architecture to ``file``. | |||||
Parameters | |||||
---------- | |||||
file : str | |||||
File path to export to. Expected to be a JSON. | |||||
""" | |||||
pass | |||||
def checkpoint(self): | |||||
""" | |||||
Override to dump a checkpoint. | |||||
""" | |||||
pass | |||||
def init_dataloader(self): | |||||
self.logger.info("Build dataloader") | |||||
self.dataset_train, self.dataset_valid = datasets.get_dataset("cifar10", self.data_dir) | |||||
n_train = len(self.dataset_train) | |||||
split = n_train // 10 | |||||
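        # hold out the last 10% of the training set for validation; the official
        # test split (dataset_valid) is reserved for final testing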
indices = list(range(n_train)) | |||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) | |||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) | |||||
self.train_loader = torch.utils.data.DataLoader(self.dataset_train, | |||||
batch_size=self.batch_size, | |||||
sampler=train_sampler, | |||||
num_workers=self.workers) | |||||
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, | |||||
batch_size=self.eval_batch_size, | |||||
sampler=valid_sampler, | |||||
num_workers=self.workers) | |||||
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, | |||||
batch_size=self.batch_size, | |||||
num_workers=self.workers) | |||||
# self.train_loader = cycle(self.train_loader) | |||||
# self.valid_loader = cycle(self.valid_loader) | |||||
def train_one_epoch(self,epoch): | |||||
""" | |||||
Train one epoch. | |||||
Parameters | |||||
---------- | |||||
epoch : int | |||||
Epoch number starting from 0. | |||||
""" | |||||
tot_acc = 0 | |||||
tot = 0 | |||||
losses = [] | |||||
step = 0 | |||||
self.child_model.train() | |||||
meters = AverageMeterGroup() | |||||
for batch in self.train_loader: | |||||
step += 1 | |||||
x, y = batch | |||||
x, y = to_device(x, self.device), to_device(y, self.device) | |||||
logits = self.child_model(x) | |||||
if isinstance(logits, tuple): | |||||
logits, aux_logits = logits | |||||
aux_loss = self.criterion(aux_logits, y) | |||||
else: | |||||
aux_loss = 0. | |||||
acc = accuracy(logits, y) | |||||
loss = self.criterion(logits, y) | |||||
loss = loss + self.aux_weight * aux_loss | |||||
self.optimizer.zero_grad() | |||||
loss.backward() | |||||
grad_norm = 0 | |||||
trainable_params = self.child_model.parameters() | |||||
            if self.child_grad_bound is not None:
                grad_norm = nn.utils.clip_grad_norm_(trainable_params, self.child_grad_bound)
self.optimizer.step() | |||||
tot_acc += acc['acc1'] | |||||
tot += 1 | |||||
            losses.append(loss.item())
acc["loss"] = loss.item() | |||||
meters.update(acc) | |||||
if step % self.log_every == 0: | |||||
curr_time = time.time() | |||||
log_string = "" | |||||
log_string += "epoch={:<6d}".format(epoch) | |||||
log_string += "ch_step={:<6d}".format(step) | |||||
log_string += " loss={:<8.6f}".format(loss) | |||||
log_string += " lr={:<8.4f}".format(self.optimizer.param_groups[0]['lr']) | |||||
log_string += " |g|={:<8.4f}".format(grad_norm) | |||||
log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0]) | |||||
log_string += " mins={:<10.2f}".format(float(curr_time - self.start_time) / 60) | |||||
self.logger.info(log_string) | |||||
print("Model Epoch [%d/%d] %.3f mins %s \n " % (epoch + 1, | |||||
self.num_epochs, float(time.time() - self.start_time) / 60, meters )) | |||||
final_acc = float(tot_acc) / tot | |||||
losses = torch.tensor(losses) | |||||
loss = losses.mean() | |||||
def validate_one_epoch(self,epoch): | |||||
tot_acc = 0 | |||||
tot = 0 | |||||
losses = [] | |||||
        meters = AverageMeterGroup()
        with torch.no_grad():  # save memory
for batch in self.valid_loader: | |||||
x, y = batch | |||||
x, y = to_device(x, self.device), to_device(y, self.device) | |||||
logits = self.child_model(x) | |||||
if isinstance(logits, tuple): | |||||
logits, aux_logits = logits | |||||
aux_loss = self.criterion(aux_logits, y) | |||||
else: | |||||
aux_loss = 0. | |||||
loss = self.criterion(logits, y) | |||||
loss = loss + self.aux_weight * aux_loss | |||||
# loss = loss.mean() | |||||
preds = logits.argmax(dim=1).long() | |||||
acc = torch.eq(preds, y.long()).long().sum().item() | |||||
acc_v = accuracy(logits, y) | |||||
                losses.append(loss.item())
tot_acc += acc | |||||
tot += len(y) | |||||
acc_v["loss"] = loss.item() | |||||
meters.update(acc_v) | |||||
losses = torch.tensor(losses) | |||||
loss = losses.mean() | |||||
if tot > 0: | |||||
final_acc = float(tot_acc) / tot | |||||
else: | |||||
final_acc = 0 | |||||
self.logger.info("Error in calculating final_acc") | |||||
with open(self.result_path, "a") as file: | |||||
file.write( | |||||
str({"type": "Accuracy", | |||||
"result": {"sequence": epoch, "category": "epoch", "value": final_acc}}) + '\n') | |||||
# print("Model eval %.3fmins %s \n " % ( | |||||
# float(time.time() - self.start_time) / 60, meters )) | |||||
print({"type": "Accuracy", | |||||
"result": {"sequence": epoch, "category": "epoch", "value": final_acc}}) | |||||
self.logger.info( | |||||
"ch_step= {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format( "test", final_acc, "test", loss)) | |||||
logging.basicConfig(format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s', | |||||
level=logging.INFO, | |||||
filename='./retrain.log', | |||||
filemode='a') | |||||
logger = logging.getLogger("enas-retrain") | |||||
def parse_args(): | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument( | |||||
"--data_dir", | |||||
type=str, | |||||
default="./data", | |||||
help="Directory containing the dataset and embedding file. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--model_selected_space_path", | |||||
type=str, | |||||
default="./model_selected_space.json", | |||||
# required=True, | |||||
help="Architecture json file. (default: %(default)s)") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='./result.json', help="res directory") | |||||
parser.add_argument("--search_space_path", type=str, | |||||
default='./search_space.json', help="search_space directory") | |||||
parser.add_argument("--log_path", type=str, default='output/log') | |||||
parser.add_argument( | |||||
"--best_selected_space_path", | |||||
type=str, | |||||
default="./best_selected_space.json", | |||||
# required=True, | |||||
help="Best architecture selected json file by experiment. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--best_checkpoint_dir", | |||||
type=str, | |||||
default="best_checkpoint", | |||||
help="Path for saved checkpoints. (default: %(default)s)") | |||||
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
parser.add_argument("--search_for", | |||||
choices=["macro", "micro"], | |||||
default="macro") | |||||
parser.add_argument( | |||||
"--batch_size", | |||||
type=int, | |||||
default=128, | |||||
help="Number of samples each batch for training. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--eval_batch_size", | |||||
type=int, | |||||
default=128, | |||||
help="Number of samples each batch for evaluation. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--epochs", | |||||
type=int, | |||||
default=10, | |||||
help="The number of training epochs. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--lr", | |||||
type=float, | |||||
default=0.02, | |||||
help="The initial learning rate. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--is_cuda", | |||||
type=distutils.util.strtobool, | |||||
default=True, | |||||
help="Specify the device type. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--load_checkpoint", | |||||
type=distutils.util.strtobool, | |||||
default=False, | |||||
help="Whether to load checkpoint. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--log_every", | |||||
type=int, | |||||
default=50, | |||||
help="How many steps to log. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--eval_every_epochs", | |||||
type=int, | |||||
default=1, | |||||
help="How many epochs to eval. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--child_grad_bound", | |||||
type=float, | |||||
default=5.0, | |||||
help="The threshold for gradient clipping. (default: %(default)s)") # | |||||
parser.add_argument( | |||||
"--child_l2_reg", | |||||
type=float, | |||||
default=3e-6, | |||||
help="Weight decay factor. (default: %(default)s)") | |||||
parser.add_argument( | |||||
"--child_lr_decay_scheme", | |||||
type=str, | |||||
default="cosine", | |||||
help="Learning rate annealing strategy, only 'cosine' supported. (default: %(default)s)") #todo: remove | |||||
global FLAGS | |||||
FLAGS = parser.parse_args() | |||||
# decode human readable search space to model | |||||
def convert_selected_space_format(child_fixed_arc): | |||||
# with open('./macro_selected_space.json') as js: | |||||
with open(child_fixed_arc) as js: | |||||
selected_space = json.load(js) | |||||
ops = selected_space['op_list'] | |||||
selected_space.pop('op_list') | |||||
new_selected_space = {} | |||||
for key, value in selected_space.items(): | |||||
# for macro | |||||
if FLAGS.search_for == 'macro': | |||||
new_key = key.split('_')[-1] | |||||
# for micro | |||||
elif FLAGS.search_for == 'micro': | |||||
new_key = key | |||||
        if len(value) > 1 or len(value) == 0:
            new_value = value
        elif value[0] in ops:
            new_value = ops.index(value[0])
        else:
            new_value = value[0]
new_selected_space[new_key] = new_value | |||||
return new_selected_space | |||||
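# Illustrative sketch of the decoding above (key names are assumptions based on
# how the search script exports the space): a macro entry such as
#   {"op_list": ["conv3x3", "conv5x5"], "layer_0_LayerChoice1": ["conv5x5"]}
# becomes {"LayerChoice1": 1}: the layer prefix is stripped from the key and the
# chosen op name is replaced by its index in op_list.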
def set_random_seed(seed):
    logger.info("set random seed for data reading: {}".format(seed))
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if FLAGS.is_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
def main(): | |||||
parse_args() | |||||
    child_fixed_arc = FLAGS.best_selected_space_path  # e.g. './macro_selected_space'
search_for = FLAGS.search_for | |||||
# set seed to result todo: trial ID | |||||
set_random_seed(FLAGS.trial_id) | |||||
mkdirs(FLAGS.result_path, FLAGS.log_path, FLAGS.best_checkpoint_dir) | |||||
# define and load model | |||||
    logger.info('** ' + FLAGS.search_for + ' search **')
fixed_arc = convert_selected_space_format(child_fixed_arc) | |||||
# Model, macro search or micro search | |||||
if FLAGS.search_for == 'macro': | |||||
child_model = GeneralNetwork() | |||||
elif FLAGS.search_for == 'micro': | |||||
child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) | |||||
apply_fixed_architecture(child_model, fixed_arc) | |||||
# load model | |||||
if FLAGS.load_checkpoint: | |||||
print('** Load model **') | |||||
logger.info('** Load model **') | |||||
child_model.load_state_dict(torch.load(FLAGS.best_checkpoint_dir)['child_model_state_dict']) | |||||
retrainer = EnasRetrainer(model=child_model, | |||||
                              data_dir=FLAGS.data_dir,
best_checkpoint_dir=FLAGS.best_checkpoint_dir, | |||||
batch_size=FLAGS.batch_size, | |||||
eval_batch_size=FLAGS.eval_batch_size, | |||||
num_epochs=FLAGS.epochs, | |||||
lr=FLAGS.lr, | |||||
is_cuda=FLAGS.is_cuda, | |||||
log_every=FLAGS.log_every, | |||||
child_grad_bound=FLAGS.child_grad_bound, | |||||
child_l2_reg=FLAGS.child_l2_reg, | |||||
eval_every_epochs=FLAGS.eval_every_epochs, | |||||
logger=logger, | |||||
result_path=FLAGS.result_path, | |||||
) | |||||
t1 = time.time() | |||||
retrainer.train() | |||||
    print('cost time for retrain: ', time.time() - t1)
if __name__ == "__main__": | |||||
main() |
@@ -0,0 +1,135 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import sys | |||||
sys.path.append('../..')
import logging | |||||
import time | |||||
from argparse import ArgumentParser | |||||
import torch | |||||
import torch.nn as nn | |||||
import datasets | |||||
from macro import GeneralNetwork | |||||
from micro import MicroNetwork | |||||
from trainer import EnasTrainer | |||||
from mutator import EnasMutator | |||||
from pytorch.callbacks import (ArchitectureCheckpoint, | |||||
LRSchedulerCallback) | |||||
from utils import accuracy, reward_accuracy | |||||
from collections import OrderedDict | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
import json
import numpy as np
import random
torch.cuda.set_device(4) | |||||
logger = logging.getLogger('tadl-enas') | |||||
# save search space as search_space.json | |||||
def save_nas_search_space(mutator, file_path):
result = OrderedDict() | |||||
cur_layer_idx = None | |||||
for mutable in mutator.mutables.traverse(): | |||||
if not isinstance(mutable,(LayerChoice, InputChoice)): | |||||
cur_layer_idx = mutable.key + '_' | |||||
continue | |||||
# macro | |||||
if 'layer' in cur_layer_idx: | |||||
if isinstance(mutable, LayerChoice): | |||||
if 'op_list' not in result: | |||||
result['op_list'] = [str(i) for i in mutable] | |||||
result[cur_layer_idx + mutable.key] = 'op_list' | |||||
else: | |||||
result[cur_layer_idx + mutable.key] = {'skip_connection': False if mutable.n_chosen else True, | |||||
'n_chosen': mutable.n_chosen if mutable.n_chosen else '', | |||||
'choose_from': mutable.choose_from if mutable.choose_from else ''} | |||||
# micro | |||||
elif 'node' in cur_layer_idx: | |||||
if isinstance(mutable,LayerChoice): | |||||
if 'op_list' not in result: | |||||
result['op_list'] = [str(i) for i in mutable] | |||||
result[mutable.key] = 'op_list' | |||||
else: | |||||
result[mutable.key] = {'skip_connection':False if mutable.n_chosen else True, | |||||
'n_chosen': mutable.n_chosen if mutable.n_chosen else '', | |||||
'choose_from': mutable.choose_from if mutable.choose_from else ''} | |||||
    dump_global_result(file_path, result)
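# Illustrative shape of the dumped search_space.json (an assumption for the
# typical macro case):
#   {"op_list": ["<op repr>", ...],
#    "layer_0_LayerChoice1": "op_list",
#    "layer_1_InputChoice2": {"skip_connection": true, "n_chosen": "", "choose_from": [...]}}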
def dump_global_result(res_path, global_result, sort_keys=False):
with open(res_path, "w") as ss_file: | |||||
json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2) | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("enas") | |||||
parser.add_argument("--search_space_path", type=str, | |||||
default='./search_space.json', help="search_space directory") | |||||
parser.add_argument("--selected_space_path", type=str, | |||||
default='./selected_space.json', help="sapce_path_out directory") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='./result.json', help="res directory") | |||||
parser.add_argument('--trial_id', type=int, default=0, metavar='N', | |||||
help='trial_id,start from 0') | |||||
parser.add_argument("--batch-size", default=128, type=int) | |||||
parser.add_argument("--log-frequency", default=10, type=int) | |||||
parser.add_argument("--search_for", choices=["macro", "micro"], default="macro") | |||||
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") | |||||
args = parser.parse_args() | |||||
    # set random seeds
torch.manual_seed(args.trial_id) | |||||
torch.cuda.manual_seed_all(args.trial_id) | |||||
np.random.seed(args.trial_id) | |||||
random.seed(args.trial_id) | |||||
dataset_train, dataset_valid = datasets.get_dataset("cifar10") | |||||
if args.search_for == "macro": | |||||
model = GeneralNetwork() | |||||
num_epochs = args.epochs or 310 | |||||
        mutator = EnasMutator(model)
elif args.search_for == "micro": | |||||
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) | |||||
num_epochs = args.epochs or 150 | |||||
mutator = EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) | |||||
else: | |||||
raise AssertionError | |||||
    # save the full search space of the network
save_nas_search_space(mutator, args.search_space_path) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) | |||||
trainer = EnasTrainer(model, | |||||
loss=criterion, | |||||
metrics=accuracy, | |||||
reward_function=reward_accuracy, | |||||
optimizer=optimizer, | |||||
callbacks=[LRSchedulerCallback(lr_scheduler)], | |||||
batch_size=args.batch_size, | |||||
num_epochs=num_epochs, | |||||
dataset_train=dataset_train, | |||||
dataset_valid=dataset_valid, | |||||
log_frequency=args.log_frequency, | |||||
mutator=mutator, | |||||
child_model_path='./'+args.search_for+'_child_model') | |||||
logger.info(trainer.metrics) | |||||
t1 = time.time() | |||||
trainer.train() | |||||
trainer.result["cost_time"] = time.time() - t1 | |||||
    dump_global_result(args.result_path, trainer.result)
    selected_model = trainer.export_child_model(selected_space=True)
    dump_global_result(args.selected_space_path, selected_model)
@@ -0,0 +1,18 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
from pytorch.selector import Selector | |||||
class EnasSelector(Selector): | |||||
def __init__(self, *args, single_candidate=True): | |||||
super().__init__(single_candidate) | |||||
self.args = args | |||||
def fit(self): | |||||
""" | |||||
only one candatite, function passed | |||||
""" | |||||
pass | |||||
if __name__ == "__main__": | |||||
hpo_selector = EnasSelector() | |||||
hpo_selector.fit() |
@@ -0,0 +1,436 @@ | |||||
from itertools import cycle | |||||
import os | |||||
import sys | |||||
sys.path.append('../..')
import numpy as np | |||||
import random | |||||
import logging | |||||
import time | |||||
from argparse import ArgumentParser | |||||
from collections import OrderedDict | |||||
import json | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.optim as optim | |||||
# import custom libraries | |||||
import datasets | |||||
from pytorch.trainer import Trainer | |||||
from pytorch.utils import AverageMeterGroup, to_device, mkdirs | |||||
from pytorch.mutables import LayerChoice, InputChoice, MutableScope | |||||
from macro import GeneralNetwork | |||||
from micro import MicroNetwork | |||||
# from trainer import EnasTrainer | |||||
from mutator import EnasMutator | |||||
from pytorch.callbacks import (ArchitectureCheckpoint, | |||||
LRSchedulerCallback) | |||||
from utils import accuracy, reward_accuracy | |||||
torch.cuda.set_device(0) | |||||
logging.basicConfig(format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s', | |||||
level=logging.INFO, | |||||
filename='./train.log', | |||||
filemode='a') | |||||
logger = logging.getLogger('enas_train') | |||||
class EnasTrainer(Trainer): | |||||
""" | |||||
ENAS trainer. | |||||
Parameters | |||||
---------- | |||||
model : nn.Module | |||||
PyTorch model to be trained. | |||||
loss : callable | |||||
Receives logits and ground truth label, return a loss tensor. | |||||
metrics : callable | |||||
Receives logits and ground truth label, return a dict of metrics. | |||||
reward_function : callable | |||||
        Receives logits and ground truth label, returns a tensor, which will be fed to the RL controller as reward.
optimizer : Optimizer | |||||
The optimizer used for optimizing the model. | |||||
num_epochs : int | |||||
Number of epochs planned for training. | |||||
dataset_train : Dataset | |||||
Dataset for training. Will be split for training weights and architecture weights. | |||||
dataset_valid : Dataset | |||||
Dataset for testing. | |||||
mutator : EnasMutator | |||||
Use when customizing your own mutator or a mutator with customized parameters. | |||||
batch_size : int | |||||
Batch size. | |||||
workers : int | |||||
Workers for data loading. | |||||
device : torch.device | |||||
``torch.device("cpu")`` or ``torch.device("cuda")``. | |||||
log_frequency : int | |||||
Step count per logging. | |||||
callbacks : list of Callback | |||||
list of callbacks to trigger at events. | |||||
entropy_weight : float | |||||
Weight of sample entropy loss. | |||||
skip_weight : float | |||||
Weight of skip penalty loss. | |||||
baseline_decay : float | |||||
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``. | |||||
child_steps : int | |||||
How many mini-batches for model training per epoch. | |||||
mutator_lr : float | |||||
Learning rate for RL controller. | |||||
mutator_steps_aggregate : int | |||||
Number of steps that will be aggregated into one mini-batch for RL controller. | |||||
mutator_steps : int | |||||
Number of mini-batches for each epoch of RL controller learning. | |||||
aux_weight : float | |||||
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss. | |||||
test_arc_per_epoch : int | |||||
How many architectures are chosen for direct test after each epoch. | |||||
""" | |||||
def __init__(self, model, loss, metrics, reward_function, | |||||
optimizer, num_epochs, dataset_train, dataset_valid, | |||||
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, | |||||
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500, | |||||
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4, | |||||
                 test_arc_per_epoch=1, child_model_path='./', result_path='./'):
super().__init__(model, mutator if mutator is not None else EnasMutator(model), | |||||
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, | |||||
batch_size, workers, device, log_frequency, callbacks) | |||||
self.reward_function = reward_function | |||||
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) | |||||
self.batch_size = batch_size | |||||
self.workers = workers | |||||
self.entropy_weight = entropy_weight | |||||
self.skip_weight = skip_weight | |||||
self.baseline_decay = baseline_decay | |||||
self.baseline = 0. | |||||
self.mutator_steps_aggregate = mutator_steps_aggregate | |||||
self.mutator_steps = mutator_steps | |||||
self.child_steps = child_steps | |||||
self.aux_weight = aux_weight | |||||
self.test_arc_per_epoch = test_arc_per_epoch | |||||
self.child_model_path = child_model_path # saving the child model | |||||
self.init_dataloader() | |||||
# self.result = {'accuracy':[], | |||||
# 'cost_time':0} | |||||
self.result_path = result_path | |||||
with open(self.result_path, "w") as file: | |||||
file.write('') | |||||
def init_dataloader(self): | |||||
n_train = len(self.dataset_train) | |||||
split = n_train // 10 | |||||
indices = list(range(n_train)) | |||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) | |||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) | |||||
self.train_loader = torch.utils.data.DataLoader(self.dataset_train, | |||||
batch_size=self.batch_size, | |||||
sampler=train_sampler, | |||||
num_workers=self.workers) | |||||
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, | |||||
batch_size=self.batch_size, | |||||
sampler=valid_sampler, | |||||
num_workers=self.workers) | |||||
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, | |||||
batch_size=self.batch_size, | |||||
num_workers=self.workers) | |||||
self.train_loader = cycle(self.train_loader) | |||||
self.valid_loader = cycle(self.valid_loader) | |||||
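        # note: itertools.cycle caches the first pass over each loader, so later
        # "epochs" replay the batch order produced by the first pass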
def train_one_epoch(self, epoch): | |||||
# Sample model and train | |||||
self.model.train() | |||||
self.mutator.eval() | |||||
meters = AverageMeterGroup() | |||||
for step in range(1, self.child_steps + 1): | |||||
x, y = next(self.train_loader) | |||||
x, y = to_device(x, self.device), to_device(y, self.device) | |||||
self.optimizer.zero_grad() | |||||
with torch.no_grad(): | |||||
self.mutator.reset() | |||||
# self._write_graph_status() | |||||
logits = self.model(x) | |||||
if isinstance(logits, tuple): | |||||
logits, aux_logits = logits | |||||
aux_loss = self.loss(aux_logits, y) | |||||
else: | |||||
aux_loss = 0. | |||||
metrics = self.metrics(logits, y) | |||||
loss = self.loss(logits, y) | |||||
loss = loss + self.aux_weight * aux_loss | |||||
loss.backward() | |||||
nn.utils.clip_grad_norm_(self.model.parameters(), 5.) | |||||
self.optimizer.step() | |||||
metrics["loss"] = loss.item() | |||||
meters.update(metrics) | |||||
if self.log_frequency is not None and step % self.log_frequency == 0: | |||||
logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, | |||||
self.num_epochs, step, self.child_steps, meters) | |||||
# Train sampler (mutator) | |||||
self.model.eval() | |||||
self.mutator.train() | |||||
meters = AverageMeterGroup() | |||||
for mutator_step in range(1, self.mutator_steps + 1): | |||||
self.mutator_optim.zero_grad() | |||||
for step in range(1, self.mutator_steps_aggregate + 1): | |||||
x, y = next(self.valid_loader) | |||||
x, y = to_device(x, self.device), to_device(y, self.device) | |||||
self.mutator.reset() | |||||
with torch.no_grad(): | |||||
logits = self.model(x) | |||||
# self._write_graph_status() | |||||
metrics = self.metrics(logits, y) | |||||
reward = self.reward_function(logits, y) | |||||
if self.entropy_weight: | |||||
reward += self.entropy_weight * self.mutator.sample_entropy.item() | |||||
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) | |||||
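                # REINFORCE-style update: weight the sampled architecture's
                # log-probability by its advantage (reward minus the moving-average baseline)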
loss = self.mutator.sample_log_prob * (reward - self.baseline) | |||||
if self.skip_weight: | |||||
loss += self.skip_weight * self.mutator.sample_skip_penalty | |||||
metrics["reward"] = reward | |||||
metrics["loss"] = loss.item() | |||||
metrics["ent"] = self.mutator.sample_entropy.item() | |||||
metrics["log_prob"] = self.mutator.sample_log_prob.item() | |||||
metrics["baseline"] = self.baseline | |||||
metrics["skip"] = self.mutator.sample_skip_penalty | |||||
loss /= self.mutator_steps_aggregate | |||||
loss.backward() | |||||
meters.update(metrics) | |||||
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate | |||||
if self.log_frequency is not None and cur_step % self.log_frequency == 0: | |||||
logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs, | |||||
mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate, | |||||
meters) | |||||
nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.) | |||||
self.mutator_optim.step() | |||||
def validate_one_epoch(self, epoch): | |||||
with torch.no_grad(): | |||||
            avg_acc = 0
for arc_id in range(self.test_arc_per_epoch): | |||||
meters = AverageMeterGroup() | |||||
                count, acc_this_round = 0, 0
for x, y in self.test_loader: | |||||
x, y = to_device(x, self.device), to_device(y, self.device) | |||||
self.mutator.reset() | |||||
child_model = self.export_child_model() | |||||
# self._generate_child_model(epoch, | |||||
# count, | |||||
# arc_id, | |||||
# child_model, | |||||
# self.child_model_path) | |||||
logits = self.model(x) | |||||
if isinstance(logits, tuple): | |||||
logits, _ = logits | |||||
metrics = self.metrics(logits, y) | |||||
loss = self.loss(logits, y) | |||||
metrics["loss"] = loss.item() | |||||
meters.update(metrics) | |||||
count += 1 | |||||
acc_this_round += metrics['acc1'] | |||||
logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s", | |||||
epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch, | |||||
meters.summary()) | |||||
acc_this_round /= count | |||||
                avg_acc += acc_this_round
# logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) | |||||
print({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) | |||||
with open(self.result_path, "a") as file: | |||||
file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", | |||||
"value": meters.get_last_acc()}}) + '\n') | |||||
# self.result['accuracy'].append(accuracy / self.test_arc_per_epoch) | |||||
# export child_model | |||||
def export_child_model(self, selected_space=False): | |||||
if selected_space: | |||||
sampled = self.mutator.sample_final() | |||||
else: | |||||
sampled = self.mutator._cache | |||||
result = OrderedDict() | |||||
cur_layer_id = None | |||||
for mutable in self.mutator.mutables: | |||||
if not isinstance(mutable, (LayerChoice, InputChoice)): | |||||
cur_layer_id = mutable.key | |||||
# not supported as built-in | |||||
continue | |||||
            chosen_ops_idx = self.mutator._convert_mutable_decision_to_human_readable(mutable, sampled[mutable.key])
            if not isinstance(chosen_ops_idx, list):
                chosen_ops_idx = [chosen_ops_idx]
            if isinstance(mutable, LayerChoice):
                if 'op_list' not in result:
                    result['op_list'] = [str(i) for i in mutable]
                chosen_ops = [str(mutable[idx]) for idx in chosen_ops_idx]
            else:
                chosen_ops = chosen_ops_idx
            if 'node' in cur_layer_id:
                result[mutable.key] = chosen_ops
            else:
                result[cur_layer_id + '_' + mutable.key] = chosen_ops
return result | |||||
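    # Exported format (illustrative): for macro search the dict looks like
    #   {"op_list": [...], "layer_0_LayerChoice1": ["<op repr>"], ...}
    # which is what convert_selected_space_format() in the retrain script decodes
    # back into op indices.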
def _generate_child_model(self, | |||||
validation_epoch, | |||||
model_idx, | |||||
validation_step, | |||||
child_model, | |||||
file_path): | |||||
# create child_models folder | |||||
# parent_path = os.path.join(file_path, 'child_model') | |||||
parent_path = file_path | |||||
if not os.path.exists(parent_path): | |||||
os.mkdir(parent_path) | |||||
# create secondary directory | |||||
secondary_path = os.path.join(parent_path, 'validation_epoch_{}'.format(validation_epoch)) | |||||
if not os.path.exists(secondary_path): | |||||
os.mkdir(secondary_path) | |||||
# create third directory | |||||
folder_path = os.path.join(secondary_path, 'validation_step_{}'.format(validation_step)) | |||||
if not os.path.exists(folder_path): | |||||
os.mkdir(folder_path) | |||||
# save sampled child_model for validation | |||||
saved_path = os.path.join(folder_path, "child_model_%02d.json" % model_idx) | |||||
with open(saved_path, "w") as ss_file: | |||||
json.dump(child_model, ss_file, indent=2) | |||||
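    # Resulting layout (illustrative):
    #   <file_path>/validation_epoch_<e>/validation_step_<s>/child_model_<idx>.json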
# save search space as search_space.json | |||||
def save_nas_search_space(mutator, file_path):
result = OrderedDict() | |||||
cur_layer_idx = None | |||||
for mutable in mutator.mutables.traverse(): | |||||
if not isinstance(mutable,(LayerChoice, InputChoice)): | |||||
cur_layer_idx = mutable.key + '_' | |||||
continue | |||||
# macro | |||||
if 'layer' in cur_layer_idx: | |||||
if isinstance(mutable, LayerChoice): | |||||
if 'op_list' not in result: | |||||
result['op_list'] = [str(i) for i in mutable] | |||||
result[cur_layer_idx + mutable.key] = 'op_list' | |||||
else: | |||||
result[cur_layer_idx + mutable.key] = {'skip_connection': False if mutable.n_chosen else True, | |||||
'n_chosen': mutable.n_chosen if mutable.n_chosen else '', | |||||
'choose_from': mutable.choose_from if mutable.choose_from else ''} | |||||
# micro | |||||
elif 'node' in cur_layer_idx: | |||||
if isinstance(mutable,LayerChoice): | |||||
if 'op_list' not in result: | |||||
result['op_list'] = [str(i) for i in mutable] | |||||
result[mutable.key] = 'op_list' | |||||
else: | |||||
result[mutable.key] = {'skip_connection':False if mutable.n_chosen else True, | |||||
'n_chosen': mutable.n_chosen if mutable.n_chosen else '', | |||||
'choose_from': mutable.choose_from if mutable.choose_from else ''} | |||||
    dump_global_result(file_path, result)
def dump_global_result(res_path, global_result, sort_keys=False):
with open(res_path, "w") as ss_file: | |||||
json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2) | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("enas") | |||||
parser.add_argument( | |||||
"--data_dir", | |||||
type=str, | |||||
default="./data", | |||||
help="Directory containing the dataset and embedding file. (default: %(default)s)") | |||||
parser.add_argument("--model_selected_space_path", type=str, | |||||
default='./model_selected_space.json', help="sapce_path_out directory") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='./model_result.json', help="res directory") | |||||
parser.add_argument("--search_space_path", type=str, | |||||
default='./search_space.json', help="search_space directory") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./model_selected_space.json', help="Best sapce_path_out directory of experiment") | |||||
parser.add_argument('--trial_id', type=int, default=0, metavar='N', | |||||
help='trial_id,start from 0') | |||||
parser.add_argument('--lr', type=float, default=0.005, metavar='N', | |||||
help='learning rate') | |||||
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") | |||||
parser.add_argument("--batch_size", default=128, type=int) | |||||
parser.add_argument("--log_frequency", default=10, type=int) | |||||
parser.add_argument("--search_for", choices=["macro", "micro"], default="macro") | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.search_space_path, args.best_selected_space_path) | |||||
    # set random seeds
torch.manual_seed(args.trial_id) | |||||
torch.cuda.manual_seed_all(args.trial_id) | |||||
np.random.seed(args.trial_id) | |||||
random.seed(args.trial_id) | |||||
    # use deterministic algorithms instead of nondeterministic ones to
    # make sure exact results can be reproduced every time.
torch.backends.cudnn.deterministic = True | |||||
dataset_train, dataset_valid = datasets.get_dataset("cifar10", args.data_dir) | |||||
if args.search_for == "macro": | |||||
model = GeneralNetwork() | |||||
num_epochs = args.epochs or 310 | |||||
        mutator = EnasMutator(model)
elif args.search_for == "micro": | |||||
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) | |||||
num_epochs = args.epochs or 150 | |||||
mutator = EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) | |||||
else: | |||||
raise AssertionError | |||||
    # save the full search space of the network
save_nas_search_space(mutator, args.search_space_path) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) | |||||
trainer = EnasTrainer(model, | |||||
loss=criterion, | |||||
metrics=accuracy, | |||||
reward_function=reward_accuracy, | |||||
optimizer=optimizer, | |||||
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./"+args.search_for+"_checkpoints")], | |||||
batch_size=args.batch_size, | |||||
num_epochs=num_epochs, | |||||
dataset_train=dataset_train, | |||||
dataset_valid=dataset_valid, | |||||
log_frequency=args.log_frequency, | |||||
mutator=mutator, | |||||
child_steps=2, | |||||
mutator_steps=2, | |||||
child_model_path='./'+args.search_for+'_child_model', | |||||
result_path=args.result_path) | |||||
logger.info(trainer.metrics) | |||||
t1 = time.time() | |||||
trainer.train() | |||||
# trainer.result["cost_time"] = time.time() - t1 | |||||
# dump_global_result(args.result_path,trainer.result) | |||||
    selected_model = trainer.export_child_model(selected_space=True)
    dump_global_result(args.best_selected_space_path, selected_model)
@@ -0,0 +1,30 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import torch | |||||
def accuracy(output, target, topk=(1,)): | |||||
""" Computes the precision@k for the specified values of k """ | |||||
maxk = max(topk) | |||||
batch_size = target.size(0) | |||||
_, pred = output.topk(maxk, 1, True, True) | |||||
pred = pred.t() | |||||
# one-hot case | |||||
if target.ndimension() > 1: | |||||
target = target.max(1)[1] | |||||
correct = pred.eq(target.view(1, -1).expand_as(pred)) | |||||
res = dict() | |||||
for k in topk: | |||||
        correct_k = correct[:k].reshape(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() | |||||
return res | |||||
def reward_accuracy(output, target, topk=(1,)): | |||||
batch_size = target.size(0) | |||||
_, predicted = torch.max(output.data, 1) | |||||
return (predicted == target).sum().item() / batch_size |
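# Minimal self-check (an illustrative sketch, not part of the original module):
if __name__ == "__main__":
    logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # batch of 2, 2 classes
    target = torch.tensor([1, 0])                    # both predictions correct
    print(accuracy(logits, target))         # -> {'acc1': 1.0}
    print(reward_accuracy(logits, target))  # -> 1.0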
@@ -0,0 +1,141 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import json | |||||
import logging | |||||
from .mutables import InputChoice, LayerChoice, MutableScope | |||||
from .mutator import Mutator | |||||
from .utils import to_list | |||||
_logger = logging.getLogger(__name__) | |||||
_logger.setLevel(logging.INFO) | |||||
class FixedArchitecture(Mutator): | |||||
""" | |||||
Fixed architecture mutator that always selects a certain graph. | |||||
Parameters | |||||
---------- | |||||
model : nn.Module | |||||
A mutable network. | |||||
fixed_arc : dict | |||||
Preloaded architecture object. | |||||
strict : bool | |||||
Force everything that appears in ``fixed_arc`` to be used at least once. | |||||
""" | |||||
def __init__(self, model, fixed_arc, strict=True): | |||||
super().__init__(model) | |||||
self._fixed_arc = fixed_arc | |||||
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)]) | |||||
fixed_arc_keys = set(self._fixed_arc.keys()) | |||||
if fixed_arc_keys - mutable_keys: | |||||
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys)) | |||||
if mutable_keys - fixed_arc_keys: | |||||
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys)) | |||||
self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc) | |||||
def _from_human_readable_architecture(self, human_arc): | |||||
# convert from an exported architecture | |||||
result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc. | |||||
        # First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
        # which means {"op1": [0, ]} or {"op1": ["conv", ]}
        result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
        # Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
        # This is non-trivial: if an array is [0, 1], we cannot know for sure whether it means [false, true] or [true, true].
        # Here, we assume a multi-hot array has to be a boolean or float array whose length matches the number of candidates.
for mutable in self.mutables: | |||||
if mutable.key not in result_arc: | |||||
continue # skip silently | |||||
choice_arr = result_arc[mutable.key] | |||||
if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr): | |||||
if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \ | |||||
(isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)): | |||||
# multihot, do nothing | |||||
continue | |||||
if isinstance(mutable, LayerChoice): | |||||
choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr] | |||||
choice_arr = [i in choice_arr for i in range(len(mutable))] | |||||
elif isinstance(mutable, InputChoice): | |||||
choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr] | |||||
choice_arr = [i in choice_arr for i in range(mutable.n_candidates)] | |||||
result_arc[mutable.key] = choice_arr | |||||
return result_arc | |||||
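    # Example of the conversion above (illustrative): for a LayerChoice with
    # candidates ["conv3x3", "conv5x5", "maxpool"], the human-readable entry
    # {"op1": "conv5x5"} becomes the multi-hot mask {"op1": [False, True, False]};
    # an entry that is already a boolean/float array of matching length is kept as-is.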
def sample_search(self): | |||||
""" | |||||
Always returns the fixed architecture. | |||||
""" | |||||
return self._fixed_arc | |||||
def sample_final(self): | |||||
""" | |||||
Always returns the fixed architecture. | |||||
""" | |||||
return self._fixed_arc | |||||
def replace_layer_choice(self, module=None, prefix=""): | |||||
""" | |||||
        Replace layer choices with the selected candidates. This is done on a best-effort basis:
        in case of weighted or multiple choices, candidates weighted with zero are deleted;
        in case of a single choice, the layer choice is replaced with the chosen module directly.
Parameters | |||||
---------- | |||||
module : nn.Module | |||||
Module to be processed. | |||||
prefix : str | |||||
Module name under global namespace. | |||||
""" | |||||
if module is None: | |||||
module = self.model | |||||
for name, mutable in module.named_children(): | |||||
global_name = (prefix + "." if prefix else "") + name | |||||
if isinstance(mutable, LayerChoice): | |||||
chosen = self._fixed_arc[mutable.key] | |||||
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask: | |||||
# sum is one, max is one, there has to be an only one | |||||
# this is compatible with both integer arrays, boolean arrays and float arrays | |||||
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1)) | |||||
setattr(module, name, mutable[chosen.index(1)]) | |||||
else: | |||||
if mutable.return_mask: | |||||
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \ | |||||
"LayerChoice will not be replaced.") | |||||
# remove unused parameters | |||||
for ch, n in zip(chosen, mutable.names): | |||||
if ch == 0 and not isinstance(ch, float): | |||||
setattr(mutable, n, None) | |||||
else: | |||||
self.replace_layer_choice(mutable, global_name) | |||||
def apply_fixed_architecture(model, fixed_arc): | |||||
""" | |||||
Load architecture from `fixed_arc` and apply to model. | |||||
Parameters | |||||
---------- | |||||
model : torch.nn.Module | |||||
Model with mutables. | |||||
fixed_arc : str or dict | |||||
Path to the JSON that stores the architecture, or dict that stores the exported architecture. | |||||
Returns | |||||
------- | |||||
FixedArchitecture | |||||
        Mutator that is responsible for fixing the graph.
""" | |||||
if isinstance(fixed_arc, str): | |||||
with open(fixed_arc) as f: | |||||
fixed_arc = json.load(f) | |||||
architecture = FixedArchitecture(model, fixed_arc) | |||||
architecture.reset() | |||||
# for the convenience of parameters counting | |||||
architecture.replace_layer_choice() | |||||
return architecture |
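# Typical usage (a sketch; ``Net`` and the JSON path are illustrative):
#   model = Net()                                         # model containing mutables
#   apply_fixed_architecture(model, "architecture.json")  # freeze to the exported arc
#   # `model` can now be retrained or evaluated as an ordinary nn.Module.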
@@ -0,0 +1,79 @@ | |||||
# -*- coding: utf-8 -*- | |||||
from datetime import datetime | |||||
from io import TextIOBase | |||||
import logging | |||||
from logging import FileHandler, Formatter, Handler, StreamHandler | |||||
from pathlib import Path | |||||
import sys | |||||
import time | |||||
from typing import Optional | |||||
time_format = '%Y-%m-%d %H:%M:%S' | |||||
formatter = Formatter( | |||||
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s', | |||||
time_format | |||||
) | |||||
def init_logger() -> None: | |||||
_setup_root_logger(StreamHandler(sys.stdout), logging.INFO) | |||||
logging.basicConfig() | |||||
def _prepare_log_dir(path: Optional[str]) -> Path: | |||||
if path is None: | |||||
return Path() | |||||
ret = Path(path) | |||||
ret.mkdir(parents=True, exist_ok=True) | |||||
return ret | |||||
def _setup_root_logger(handler: Handler, level: int) -> None: | |||||
_setup_logger('tadl', handler, level) | |||||
def _setup_logger(name: str, handler: Handler, level: int) -> None: | |||||
handler.setFormatter(formatter) | |||||
logger = logging.getLogger(name) | |||||
logger.addHandler(handler) | |||||
logger.setLevel(level) | |||||
logger.propagate = False | |||||
class _LogFileWrapper(TextIOBase): | |||||
# wrap the logger file so that anything written to it will automatically get formatted | |||||
def __init__(self, log_file: TextIOBase): | |||||
self.file: TextIOBase = log_file | |||||
self.line_buffer: Optional[str] = None | |||||
self.line_start_time: Optional[datetime] = None | |||||
def write(self, s: str) -> int: | |||||
cur_time = datetime.now() | |||||
if self.line_buffer and (cur_time - self.line_start_time).total_seconds() > 0.1: | |||||
self.flush() | |||||
if self.line_buffer: | |||||
self.line_buffer += s | |||||
else: | |||||
self.line_buffer = s | |||||
self.line_start_time = cur_time | |||||
if '\n' not in s: | |||||
return len(s) | |||||
time_str = cur_time.strftime(time_format) | |||||
lines = self.line_buffer.split('\n') | |||||
for line in lines[:-1]: | |||||
self.file.write(f'[{time_str}] PRINT {line}\n') | |||||
self.file.flush() | |||||
self.line_buffer = lines[-1] | |||||
self.line_start_time = cur_time | |||||
return len(s) | |||||
def flush(self) -> None: | |||||
if self.line_buffer: | |||||
time_str = self.line_start_time.strftime(time_format) | |||||
self.file.write(f'[{time_str}] PRINT {self.line_buffer}\n') | |||||
self.file.flush() | |||||
self.line_buffer = None | |||||
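# Illustrative usage (an assumption; the file name is arbitrary): redirect stdout
# so that bare ``print`` output is timestamped like regular log records.
#   import sys
#   sys.stdout = _LogFileWrapper(open('stdout.log', 'a'))
#   print('hello')  # written as "[<timestamp>] PRINT hello"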
@@ -0,0 +1,340 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import logging | |||||
import warnings | |||||
from collections import OrderedDict | |||||
import torch.nn as nn | |||||
from .utils import global_mutable_counting | |||||
logger = logging.getLogger(__name__) | |||||
logger.setLevel(logging.INFO) | |||||
class Mutable(nn.Module): | |||||
""" | |||||
Mutable is designed to function as a normal layer, with all necessary operators' weights. | |||||
States and weights of architectures should be included in mutator, instead of the layer itself. | |||||
Mutable has a key, which marks the identity of the mutable. This key can be used by users to share | |||||
decisions among different mutables. In mutator's implementation, mutators should use the key to | |||||
distinguish different mutables. Mutables that share the same key should be "similar" to each other. | |||||
    Currently the default scope for keys is global. By default, keys use a global counter starting from 1 to
    produce unique ids.
Parameters | |||||
---------- | |||||
key : str | |||||
The key of mutable. | |||||
Notes | |||||
----- | |||||
The counter is program level, but mutables are model level. In case multiple models are defined, and | |||||
you want to have `counter` starting from 1 in the second model, it's recommended to assign keys manually | |||||
instead of using automatic keys. | |||||
""" | |||||
def __init__(self, key=None): | |||||
super().__init__() | |||||
if key is not None: | |||||
if not isinstance(key, str): | |||||
key = str(key) | |||||
logger.warning("Warning: key \"%s\" is not string, converted to string.", key) | |||||
self._key = key | |||||
else: | |||||
self._key = self.__class__.__name__ + str(global_mutable_counting()) | |||||
self.init_hook = self.forward_hook = None | |||||
def __deepcopy__(self, memodict=None): | |||||
raise NotImplementedError("Deep copy doesn't work for mutables.") | |||||
def __call__(self, *args, **kwargs): | |||||
self._check_built() | |||||
return super().__call__(*args, **kwargs) | |||||
def set_mutator(self, mutator): | |||||
if "mutator" in self.__dict__: | |||||
raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? " | |||||
"Or did you apply multiple fixed architectures?") | |||||
self.__dict__["mutator"] = mutator | |||||
@property | |||||
def key(self): | |||||
""" | |||||
Read-only property of key. | |||||
""" | |||||
return self._key | |||||
@property | |||||
def name(self): | |||||
""" | |||||
After the search space is parsed, it will be the module name of the mutable. | |||||
""" | |||||
        return self._name if hasattr(self, "_name") else self._key
@name.setter | |||||
def name(self, name): | |||||
self._name = name | |||||
def _check_built(self): | |||||
if not hasattr(self, "mutator"): | |||||
raise ValueError( | |||||
"Mutator not set for {}. You might have forgotten to initialize and apply your mutator. " | |||||
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` " | |||||
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) | |||||
class MutableScope(Mutable): | |||||
""" | |||||
Mutable scope marks a subgraph/submodule to help mutators make better decisions. | |||||
If not annotated with mutable scope, search space will be flattened as a list. However, some mutators might | |||||
need to leverage the concept of a "cell". So if a module is defined as a mutable scope, everything in it will | |||||
look like "sub-search-space" in the scope. Scopes can be nested. | |||||
There are two ways mutators can use mutable scope. One is to traverse the search space as a tree during initialization | |||||
and reset. The other is to implement `enter_mutable_scope` and `exit_mutable_scope`. They are called before and after | |||||
the forward method of the class inheriting mutable scope. | |||||
Mutable scopes are also mutables that are listed in the mutator.mutables (search space), but they are not supposed | |||||
to appear in the dict of choices. | |||||
Parameters | |||||
---------- | |||||
key : str | |||||
Key of mutable scope. | |||||
""" | |||||
def __init__(self, key): | |||||
super().__init__(key=key) | |||||
def __call__(self, *args, **kwargs): | |||||
try: | |||||
self._check_built() | |||||
self.mutator.enter_mutable_scope(self) | |||||
return super().__call__(*args, **kwargs) | |||||
finally: | |||||
self.mutator.exit_mutable_scope(self) | |||||
class LayerChoice(Mutable): | |||||
""" | |||||
Layer choice selects one of the ``op_candidates``, then apply it on inputs and return results. | |||||
In rare cases, it can also select zero or many. | |||||
Layer choice does not allow itself to be nested. | |||||
Parameters | |||||
---------- | |||||
op_candidates : list of nn.Module or OrderedDict | |||||
A module list to be selected from. | |||||
reduction : str | |||||
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected. | |||||
If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum. | |||||
        ``concat`` concatenates the list along dimension 1.
return_mask : bool | |||||
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only. | |||||
key : str | |||||
Key of the input choice. | |||||
Attributes | |||||
---------- | |||||
length : int | |||||
Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended. | |||||
names : list of str | |||||
Names of candidates. | |||||
choices : list of Module | |||||
Deprecated. A list of all candidate modules in the layer choice module. | |||||
``list(layer_choice)`` is recommended, which will serve the same purpose. | |||||
Notes | |||||
----- | |||||
    ``op_candidates`` can be a list of modules or an ordered dict of named modules, for example,
.. code-block:: python | |||||
self.op_choice = LayerChoice(OrderedDict([ | |||||
("conv3x3", nn.Conv2d(3, 16, 128)), | |||||
("conv5x5", nn.Conv2d(5, 16, 128)), | |||||
("conv7x7", nn.Conv2d(7, 16, 128)) | |||||
])) | |||||
Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or | |||||
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet. | |||||
""" | |||||
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None): | |||||
super().__init__(key=key) | |||||
self.names = [] | |||||
if isinstance(op_candidates, OrderedDict): | |||||
for name, module in op_candidates.items(): | |||||
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \ | |||||
"Please don't use a reserved name '{}' for your module.".format(name) | |||||
self.add_module(name, module) | |||||
self.names.append(name) | |||||
elif isinstance(op_candidates, list): | |||||
for i, module in enumerate(op_candidates): | |||||
self.add_module(str(i), module) | |||||
self.names.append(str(i)) | |||||
else: | |||||
raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates))) | |||||
self.reduction = reduction | |||||
self.return_mask = return_mask | |||||
def __getitem__(self, idx): | |||||
if isinstance(idx, str): | |||||
return self._modules[idx] | |||||
return list(self)[idx] | |||||
def __setitem__(self, idx, module): | |||||
key = idx if isinstance(idx, str) else self.names[idx] | |||||
return setattr(self, key, module) | |||||
def __delitem__(self, idx): | |||||
if isinstance(idx, slice): | |||||
for key in self.names[idx]: | |||||
delattr(self, key) | |||||
else: | |||||
if isinstance(idx, str): | |||||
key, idx = idx, self.names.index(idx) | |||||
else: | |||||
key = self.names[idx] | |||||
delattr(self, key) | |||||
del self.names[idx] | |||||
@property | |||||
def length(self): | |||||
warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning) | |||||
return len(self) | |||||
def __len__(self): | |||||
return len(self.names) | |||||
def __iter__(self): | |||||
return map(lambda name: self._modules[name], self.names) | |||||
@property | |||||
def choices(self): | |||||
warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", DeprecationWarning) | |||||
return list(self) | |||||
def forward(self, *args, **kwargs): | |||||
""" | |||||
Returns | |||||
------- | |||||
tuple of tensors | |||||
Output and selection mask. If ``return_mask`` is ``False``, only output is returned. | |||||
""" | |||||
out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs) | |||||
if self.return_mask: | |||||
return out, mask | |||||
return out | |||||
class InputChoice(Mutable): | |||||
""" | |||||
    Input choice selects ``n_chosen`` inputs from ``choose_from`` (which contains ``n_candidates`` keys). For beginners,
    using ``n_candidates`` instead of ``choose_from`` is a safe option. To get the most power out of it, you might want to
    know about ``choose_from``.
    The keys in ``choose_from`` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
    The keys are designed to be the keys of the sources. To help mutators make better decisions,
    mutators might be interested in how the tensors to choose from came into place. For example, a tensor might be the
    output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
    ``LayerChoice`` or ``InputChoice``), it naturally has a key that can be used as a source key. If it's a
    module/submodule, it needs to be annotated with a key: that's where a :class:`MutableScope` is needed.
    In the example below, ``input_choice`` is a 4-choose-any. The first three are semantically the outputs of ``cell1``,
    ``cell2`` and ``op``, respectively. Notice that an extra max pooling follows ``cell1``, so ``x1`` is not
    "actually" the direct output of ``cell1``.
.. code-block:: python | |||||
class Cell(MutableScope): | |||||
pass | |||||
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
self.cell1 = Cell("cell1") | |||||
self.cell2 = Cell("cell2") | |||||
self.op = LayerChoice([conv3x3(), conv5x5()], key="op") | |||||
self.input_choice = InputChoice(choose_from=["cell1", "cell2", "op", InputChoice.NO_KEY]) | |||||
def forward(self, x): | |||||
x1 = max_pooling(self.cell1(x)) | |||||
x2 = self.cell2(x) | |||||
x3 = self.op(x) | |||||
x4 = torch.zeros_like(x) | |||||
return self.input_choice([x1, x2, x3, x4]) | |||||
Parameters | |||||
---------- | |||||
n_candidates : int | |||||
Number of inputs to choose from. | |||||
choose_from : list of str | |||||
        List of source keys to choose from. At least one of ``choose_from`` and ``n_candidates`` must be provided.
        If ``n_candidates`` has a value but ``choose_from`` is None, it will be automatically treated as ``n_candidates``
        empty strings.
n_chosen : int | |||||
        Number of inputs to choose. If None, the mutator is instructed to select any number of them.
reduction : str | |||||
``mean``, ``concat``, ``sum`` or ``none``. See :class:`LayerChoice`. | |||||
return_mask : bool | |||||
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only. | |||||
key : str | |||||
Key of the input choice. | |||||
""" | |||||
NO_KEY = "" | |||||
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, | |||||
reduction="sum", return_mask=False, key=None): | |||||
super().__init__(key=key) | |||||
# precondition check | |||||
        assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and " \
            "`choose_from` must be not None."
if choose_from is not None and n_candidates is None: | |||||
n_candidates = len(choose_from) | |||||
elif choose_from is None and n_candidates is not None: | |||||
choose_from = [self.NO_KEY] * n_candidates | |||||
assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`." | |||||
assert n_candidates > 0, "Number of candidates must be greater than 0." | |||||
assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \ | |||||
"than number of candidates." | |||||
self.n_candidates = n_candidates | |||||
self.choose_from = choose_from.copy() | |||||
self.n_chosen = n_chosen | |||||
self.reduction = reduction | |||||
self.return_mask = return_mask | |||||
def forward(self, optional_inputs): | |||||
""" | |||||
        Forward method of InputChoice.
Parameters | |||||
---------- | |||||
optional_inputs : list or dict | |||||
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of | |||||
            ``choose_from`` in initialization. As a list, inputs must follow the same semantic order as
            ``choose_from``.
Returns | |||||
------- | |||||
tuple of tensors | |||||
Output and selection mask. If ``return_mask`` is ``False``, only output is returned. | |||||
""" | |||||
optional_input_list = optional_inputs | |||||
if isinstance(optional_inputs, dict): | |||||
optional_input_list = [optional_inputs[tag] for tag in self.choose_from] | |||||
assert isinstance(optional_input_list, list), \ | |||||
"Optional input list must be a list, not a {}.".format(type(optional_input_list)) | |||||
        assert len(optional_input_list) == self.n_candidates, \
"Length of the input list must be equal to number of candidates." | |||||
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list) | |||||
if self.return_mask: | |||||
return out, mask | |||||
return out | |||||
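# --- Illustrative sketch (not part of the library) ----------------------------
# A minimal, hypothetical model showing how LayerChoice and InputChoice are
# typically combined. It assumes ``torch.nn`` is imported as ``nn`` at the top
# of this module; ``forward`` only works after a mutator has been attached.
class _SketchNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Choose one of two convolutions with identical output shapes.
        self.op1 = LayerChoice([nn.Conv2d(3, 16, 3, padding=1),
                                nn.Conv2d(3, 16, 5, padding=2)], key="op1")
        self.op2 = nn.Conv2d(16, 16, 3, padding=1)
        # Pick exactly one of the two intermediate activations.
        self.skip = InputChoice(n_candidates=2, n_chosen=1, key="skip")

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x1)
        return self.skip([x1, x2])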
@@ -0,0 +1,309 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import logging | |||||
from collections import defaultdict | |||||
import numpy as np | |||||
import torch | |||||
from .base_mutator import BaseMutator | |||||
from .mutables import LayerChoice, InputChoice | |||||
from .utils import to_list | |||||
logger = logging.getLogger(__name__) | |||||
logger.setLevel(logging.INFO) | |||||
class Mutator(BaseMutator): | |||||
def __init__(self, model): | |||||
super().__init__(model) | |||||
self._cache = dict() | |||||
self._connect_all = False | |||||
def sample_search(self): | |||||
""" | |||||
Override to implement this method to iterate over mutables and make decisions. | |||||
Returns | |||||
------- | |||||
dict | |||||
A mapping from key of mutables to decisions. | |||||
""" | |||||
raise NotImplementedError | |||||
def sample_final(self): | |||||
""" | |||||
        Override to implement this method to iterate over mutables and make decisions that are final
        for export and retraining.
Returns | |||||
------- | |||||
dict | |||||
A mapping from key of mutables to decisions. | |||||
""" | |||||
raise NotImplementedError | |||||
def reset(self): | |||||
""" | |||||
        Reset the mutator by calling `sample_search` to resample (for search). The result is stored in a local
        variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the decision directly.
""" | |||||
self._cache = self.sample_search() | |||||
def export(self): | |||||
""" | |||||
Resample (for final) and return results. | |||||
Returns | |||||
------- | |||||
dict | |||||
A mapping from key of mutables to decisions. | |||||
""" | |||||
sampled = self.sample_final() | |||||
result = dict() | |||||
for mutable in self.mutables: | |||||
if not isinstance(mutable, (LayerChoice, InputChoice)): | |||||
# not supported as built-in | |||||
continue | |||||
result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key)) | |||||
if sampled: | |||||
raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys())) | |||||
return result | |||||
def status(self): | |||||
""" | |||||
Return current selection status of mutator. | |||||
Returns | |||||
------- | |||||
dict | |||||
            A mapping from key of mutables to decisions. All weights (boolean type and float type)
            are converted into real number values. Numpy arrays and tensors are converted into lists.
""" | |||||
data = dict() | |||||
for k, v in self._cache.items(): | |||||
if torch.is_tensor(v): | |||||
v = v.detach().cpu().numpy() | |||||
if isinstance(v, np.ndarray): | |||||
v = v.astype(np.float32).tolist() | |||||
data[k] = v | |||||
return data | |||||
def graph(self, inputs): | |||||
""" | |||||
Return model supernet graph. | |||||
Parameters | |||||
---------- | |||||
        inputs : tuple of tensor
            Inputs that will be fed into the network.
Returns | |||||
------- | |||||
dict | |||||
Containing ``node``, in Tensorboard GraphDef format. | |||||
Additional key ``mutable`` is a map from key to list of modules. | |||||
""" | |||||
if not torch.__version__.startswith("1.4"): | |||||
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.") | |||||
from nni._graph_utils import build_graph | |||||
from google.protobuf import json_format | |||||
# protobuf should be installed as long as tensorboard is installed | |||||
try: | |||||
self._connect_all = True | |||||
graph_def, _ = build_graph(self.model, inputs, verbose=False) | |||||
result = json_format.MessageToDict(graph_def) | |||||
finally: | |||||
self._connect_all = False | |||||
# `mutable` is to map the keys to a list of corresponding modules. | |||||
# A key can be linked to multiple modules, use `dedup=False` to find them all. | |||||
result["mutable"] = defaultdict(list) | |||||
for mutable in self.mutables.traverse(deduplicate=False): | |||||
            # A module will be represented in the format of
            # [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
# which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend. | |||||
# This format is aligned with the scope name jit gives. | |||||
modules = mutable.name.split(".") | |||||
path = [ | |||||
{"type": self.model.__class__.__name__, "name": ""} | |||||
] | |||||
m = self.model | |||||
for module in modules: | |||||
m = getattr(m, module) | |||||
path.append({ | |||||
"type": m.__class__.__name__, | |||||
"name": module | |||||
}) | |||||
result["mutable"][mutable.key].append(path) | |||||
return result | |||||
def on_forward_layer_choice(self, mutable, *args, **kwargs): | |||||
""" | |||||
        By default, this method retrieves the decision obtained previously, and selects certain operations.
Only operations with non-zero weight will be executed. The results will be added to a list. | |||||
Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`. | |||||
Parameters | |||||
---------- | |||||
mutable : LayerChoice | |||||
Layer choice module. | |||||
args : list of torch.Tensor | |||||
Inputs | |||||
kwargs : dict | |||||
Inputs | |||||
Returns | |||||
------- | |||||
tuple of torch.Tensor and torch.Tensor | |||||
Output and mask. | |||||
""" | |||||
if self._connect_all: | |||||
return self._all_connect_tensor_reduction(mutable.reduction, | |||||
[op(*args, **kwargs) for op in mutable]), \ | |||||
torch.ones(len(mutable)).bool() | |||||
def _map_fn(op, args, kwargs): | |||||
return op(*args, **kwargs) | |||||
mask = self._get_decision(mutable) | |||||
assert len(mask) == len(mutable), \ | |||||
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable)) | |||||
out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask) | |||||
return self._tensor_reduction(mutable.reduction, out), mask | |||||
def on_forward_input_choice(self, mutable, tensor_list): | |||||
""" | |||||
        By default, this method retrieves the decision obtained previously, and selects certain tensors.
Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`. | |||||
Parameters | |||||
---------- | |||||
mutable : InputChoice | |||||
Input choice module. | |||||
tensor_list : list of torch.Tensor | |||||
Tensor list to apply the decision on. | |||||
Returns | |||||
------- | |||||
tuple of torch.Tensor and torch.Tensor | |||||
Output and mask. | |||||
""" | |||||
if self._connect_all: | |||||
return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \ | |||||
torch.ones(mutable.n_candidates).bool() | |||||
mask = self._get_decision(mutable) | |||||
assert len(mask) == mutable.n_candidates, \ | |||||
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates) | |||||
out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) | |||||
return self._tensor_reduction(mutable.reduction, out), mask | |||||
def _select_with_mask(self, map_fn, candidates, mask): | |||||
""" | |||||
Select masked tensors and return a list of tensors. | |||||
Parameters | |||||
---------- | |||||
map_fn : function | |||||
Convert candidates to target candidates. Can be simply identity. | |||||
candidates : list of torch.Tensor | |||||
Tensor list to apply the decision on. | |||||
mask : list-like object | |||||
Can be a list, an numpy array or a tensor (recommended). Needs to | |||||
have the same length as ``candidates``. | |||||
Returns | |||||
------- | |||||
tuple of list of torch.Tensor and torch.Tensor | |||||
Output and mask. | |||||
""" | |||||
        if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \
                (isinstance(mask, np.ndarray) and mask.dtype == np.bool_) or \
                (torch.is_tensor(mask) and "BoolTensor" in mask.type()):
            out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
        elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \
                (isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
                (torch.is_tensor(mask) and "FloatTensor" in mask.type()):
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m] | |||||
else: | |||||
raise ValueError("Unrecognized mask '%s'" % mask) | |||||
if not torch.is_tensor(mask): | |||||
mask = torch.tensor(mask) # pylint: disable=not-callable | |||||
return out, mask | |||||
def _tensor_reduction(self, reduction_type, tensor_list): | |||||
if reduction_type == "none": | |||||
return tensor_list | |||||
if not tensor_list: | |||||
return None # empty. return None for now | |||||
if len(tensor_list) == 1: | |||||
return tensor_list[0] | |||||
if reduction_type == "sum": | |||||
return sum(tensor_list) | |||||
if reduction_type == "mean": | |||||
return sum(tensor_list) / len(tensor_list) | |||||
if reduction_type == "concat": | |||||
return torch.cat(tensor_list, dim=1) | |||||
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type)) | |||||
def _all_connect_tensor_reduction(self, reduction_type, tensor_list): | |||||
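        # Used by graph(): pretend every candidate is selected so that the traced
        # supernet contains all edges; note "mean" degenerates to "sum" here.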
if reduction_type == "none": | |||||
return tensor_list | |||||
if reduction_type == "concat": | |||||
return torch.cat(tensor_list, dim=1) | |||||
return torch.stack(tensor_list).sum(0) | |||||
def _get_decision(self, mutable): | |||||
""" | |||||
        By default, this method checks whether `mutable.key` is already in the decision cache,
        and returns the cached result without double-checking.
Parameters | |||||
---------- | |||||
mutable : Mutable | |||||
Returns | |||||
------- | |||||
object | |||||
""" | |||||
if mutable.key not in self._cache: | |||||
raise ValueError("\"{}\" not found in decision cache.".format(mutable.key)) | |||||
result = self._cache[mutable.key] | |||||
logger.debug("Decision %s: %s", mutable.key, result) | |||||
return result | |||||
def _convert_mutable_decision_to_human_readable(self, mutable, sampled): | |||||
# Assert the existence of mutable.key in returned architecture. | |||||
# Also check if there is anything extra. | |||||
multihot_list = to_list(sampled) | |||||
converted = None | |||||
# If it's a boolean array, we can do optimization. | |||||
if all([t == 0 or t == 1 for t in multihot_list]): | |||||
if isinstance(mutable, LayerChoice): | |||||
assert len(multihot_list) == len(mutable), \ | |||||
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \ | |||||
% (mutable.key, multihot_list) | |||||
# check if all modules have different names and they indeed have names | |||||
if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names): | |||||
converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]] | |||||
else: | |||||
converted = [i for i in range(len(multihot_list)) if multihot_list[i]] | |||||
if isinstance(mutable, InputChoice): | |||||
assert len(multihot_list) == mutable.n_candidates, \ | |||||
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \ | |||||
% (mutable.key, multihot_list) | |||||
# check if all input candidates have different names | |||||
if len(set(mutable.choose_from)) == mutable.n_candidates: | |||||
converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]] | |||||
else: | |||||
converted = [i for i in range(len(multihot_list)) if multihot_list[i]] | |||||
if converted is not None: | |||||
# if only one element, then remove the bracket | |||||
if len(converted) == 1: | |||||
converted = converted[0] | |||||
else: | |||||
# do nothing | |||||
converted = multihot_list | |||||
return converted |
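# --- Illustrative sketch (hypothetical; it mirrors, but is not guaranteed to
# match, NNI's built-in random mutator) ----------------------------------------
# A minimal subclass showing what sample_search/sample_final are expected to
# return: a dict mapping mutable keys to boolean masks.
class _RandomMutatorSketch(Mutator):
    def sample_search(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # a one-hot mask over the operator candidates
                index = torch.randint(high=len(mutable), size=(1,))
                result[mutable.key] = torch.nn.functional.one_hot(
                    index, num_classes=len(mutable)).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                if mutable.n_chosen is None:
                    # any subset of the candidate inputs
                    result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).bool()
                else:
                    # exactly n_chosen of the candidate inputs
                    perm = torch.randperm(mutable.n_candidates)
                    mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
                    result[mutable.key] = torch.tensor(mask, dtype=torch.bool)  # pylint: disable=not-callable
        return result

    def sample_final(self):
        return self.sample_search()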
@@ -0,0 +1,47 @@ | |||||
# Network Morphism | |||||
The implementation of the Network Morphism algorithm is based on
[Auto-Keras: An Efficient Neural Architecture Search System](https://arxiv.org/pdf/1806.10282.pdf).
## Train stage
```
python network_morphism_train.py \
    --trial_id 0 \
    --experiment_dir 'tadl' \
    --data_dir '../data/' \
    --result_path 'trial_id/result.json' \
    --log_path 'trial_id/log' \
    --search_space_path 'experiment_id/search_space.json' \
    --best_selected_space_path 'experiment_id/best_selected_space.json' \
    --lr 0.001 --epochs 100 --batch_size 32 --opt 'SGD'
```
## Select stage
``` | |||||
python network_morphism_select.py | |||||
``` | |||||
## Retrain stage
```
python network_morphism_retrain.py \
    --data_dir '../data/' \
    --experiment_dir 'tadl' \
    --result_path 'trial_id/result.json' \
    --log_path 'trial_id/log' \
    --best_selected_space_path 'experiment_id/best_selected_space.json' \
    --best_checkpoint_dir 'experiment_id/' \
    --trial_id 0 --batch_size 32 --opt 'SGD' --epochs 100 --lr 0.001
```
The best searched model achieved 88.1% accuracy on the CIFAR-10 dataset after 100 trials.
## Dependencies
``` | |||||
Python = 3.6.13 | |||||
pytorch = 1.8.0 | |||||
torchvision = 0.9.0 | |||||
scipy = 1.5.2 | |||||
scikit-learn = 0.24.1 | |||||
``` |
@@ -0,0 +1,517 @@ | |||||
import math | |||||
import random | |||||
from copy import deepcopy | |||||
from functools import total_ordering | |||||
from queue import PriorityQueue | |||||
import numpy as np | |||||
from scipy.linalg import LinAlgError, cho_solve, cholesky, solve_triangular | |||||
from scipy.optimize import linear_sum_assignment | |||||
from sklearn.metrics.pairwise import rbf_kernel | |||||
from .graph_transformer import transform | |||||
from .layers import is_layer | |||||
from utils import Constant, OptimizeMode | |||||
import logging | |||||
logger = logging.getLogger(__name__) | |||||
# equation(6) dl | |||||
def layer_distance(a, b): | |||||
"""The distance between two layers.""" | |||||
# pylint: disable=unidiomatic-typecheck | |||||
if not isinstance(a, type(b)): | |||||
return 1.0 | |||||
if is_layer(a, "Conv"): | |||||
att_diff = [ | |||||
(a.filters, b.filters), | |||||
(a.kernel_size, b.kernel_size), | |||||
(a.stride, b.stride), | |||||
] | |||||
return attribute_difference(att_diff) | |||||
if is_layer(a, "Pooling"): | |||||
att_diff = [ | |||||
(a.padding, b.padding), | |||||
(a.kernel_size, b.kernel_size), | |||||
(a.stride, b.stride), | |||||
] | |||||
return attribute_difference(att_diff) | |||||
return 0.0 | |||||
# equation(6) | |||||
def attribute_difference(att_diff): | |||||
''' The attribute distance. | |||||
''' | |||||
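    # e.g. att_diff = [(32, 64), (3, 3), (1, 2)] gives (0.5 + 0 + 0.5) / 3 = 1/3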
ret = 0 | |||||
for a_value, b_value in att_diff: | |||||
if max(a_value, b_value) == 0: | |||||
ret += 0 | |||||
else: | |||||
ret += abs(a_value - b_value) * 1.0 / max(a_value, b_value) | |||||
return ret * 1.0 / len(att_diff) | |||||
# equation(7) A | |||||
def layers_distance(list_a, list_b): | |||||
"""The distance between the layers of two neural networks.""" | |||||
len_a = len(list_a) | |||||
len_b = len(list_b) | |||||
f = np.zeros((len_a + 1, len_b + 1)) | |||||
f[-1][-1] = 0 | |||||
for i in range(-1, len_a): | |||||
f[i][-1] = i + 1 | |||||
for j in range(-1, len_b): | |||||
f[-1][j] = j + 1 | |||||
for i in range(len_a): | |||||
for j in range(len_b): | |||||
f[i][j] = min( | |||||
f[i][j - 1] + 1, | |||||
f[i - 1][j] + 1, | |||||
f[i - 1][j - 1] + layer_distance(list_a[i], list_b[j]), | |||||
) | |||||
return f[len_a - 1][len_b - 1] | |||||
# equation (9) ds | |||||
# 0: topo rank of the start, 1: rank of the end | |||||
def skip_connection_distance(a, b): | |||||
"""The distance between two skip-connections.""" | |||||
if a[2] != b[2]: | |||||
return 1.0 | |||||
len_a = abs(a[1] - a[0]) | |||||
len_b = abs(b[1] - b[0]) | |||||
return (abs(a[0] - b[0]) + abs(len_a - len_b)) / \ | |||||
(max(a[0], b[0]) + max(len_a, len_b)) | |||||
# equation (8) Ds | |||||
# convert equation (8) minimization part into a bipartite graph matching problem and solved by hungarian algorithm(linear_sum_assignment) | |||||
def skip_connections_distance(list_a, list_b): | |||||
"""The distance between the skip-connections of two neural networks.""" | |||||
distance_matrix = np.zeros((len(list_a), len(list_b))) | |||||
for i, a in enumerate(list_a): | |||||
for j, b in enumerate(list_b): | |||||
distance_matrix[i][j] = skip_connection_distance(a, b) | |||||
return distance_matrix[linear_sum_assignment(distance_matrix)].sum() + abs( | |||||
len(list_a) - len(list_b) | |||||
) | |||||
# equation (4) | |||||
def edit_distance(x, y): | |||||
"""The distance between two neural networks. | |||||
Args: | |||||
x: An instance of NetworkDescriptor. | |||||
y: An instance of NetworkDescriptor | |||||
Returns: | |||||
The edit-distance between x and y. | |||||
""" | |||||
ret = layers_distance(x.layers, y.layers) | |||||
ret += Constant.KERNEL_LAMBDA * skip_connections_distance( | |||||
x.skip_connections, y.skip_connections | |||||
) | |||||
return ret | |||||
class IncrementalGaussianProcess: | |||||
"""Gaussian process regressor. | |||||
Attributes: | |||||
alpha: A hyperparameter. | |||||
""" | |||||
def __init__(self): | |||||
self.alpha = 1e-10 | |||||
self._distance_matrix = None | |||||
self._x = None | |||||
self._y = None | |||||
self._first_fitted = False | |||||
self._l_matrix = None | |||||
self._alpha_vector = None | |||||
@property | |||||
def kernel_matrix(self): | |||||
        ''' Kernel matrix.
''' | |||||
return self._distance_matrix | |||||
def fit(self, train_x, train_y): | |||||
""" Fit the regressor with more data. | |||||
Args: | |||||
train_x: A list of NetworkDescriptor. | |||||
train_y: A list of metric values. | |||||
""" | |||||
if self.first_fitted: | |||||
self.incremental_fit(train_x, train_y) | |||||
else: | |||||
self.first_fit(train_x, train_y) | |||||
# compute the kernel matrix k, alpha_vector | |||||
    # Unlike first_fit, the new training samples are appended to extend the distance matrix.
def incremental_fit(self, train_x, train_y): | |||||
""" Incrementally fit the regressor. """ | |||||
if not self._first_fitted: | |||||
raise ValueError( | |||||
"The first_fit function needs to be called first.") | |||||
train_x, train_y = np.array(train_x), np.array(train_y) | |||||
# Incrementally compute K | |||||
up_right_k = edit_distance_matrix(self._x, train_x) | |||||
down_left_k = np.transpose(up_right_k) | |||||
down_right_k = edit_distance_matrix(train_x) | |||||
up_k = np.concatenate((self._distance_matrix, up_right_k), axis=1) | |||||
down_k = np.concatenate((down_left_k, down_right_k), axis=1) | |||||
temp_distance_matrix = np.concatenate((up_k, down_k), axis=0) | |||||
k_matrix = bourgain_embedding_matrix(temp_distance_matrix) | |||||
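        # The kernel matrix is rebuilt from the embedded distances, so the Cholesky
        # factor below is recomputed for the enlarged matrix rather than updated
        # incrementally.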
diagonal = np.diag_indices_from(k_matrix) | |||||
diagonal = (diagonal[0][-len(train_x):], diagonal[1][-len(train_x):]) | |||||
k_matrix[diagonal] += self.alpha | |||||
try: | |||||
self._l_matrix = cholesky(k_matrix, lower=True) # Line 2 | |||||
        except LinAlgError as err:
            logger.error('LinAlgError: %s', err)
            return self
self._x = np.concatenate((self._x, train_x), axis=0) | |||||
self._y = np.concatenate((self._y, train_y), axis=0) | |||||
self._distance_matrix = temp_distance_matrix | |||||
self._alpha_vector = cho_solve( | |||||
(self._l_matrix, True), self._y) # Line 3 | |||||
return self | |||||
@property | |||||
def first_fitted(self): | |||||
        ''' Whether the regressor has been fitted for the first time.
''' | |||||
return self._first_fitted | |||||
    # The update process: the first fit.
def first_fit(self, train_x, train_y): | |||||
""" Fit the regressor for the first time. """ | |||||
train_x, train_y = np.array(train_x), np.array(train_y) | |||||
self._x = np.copy(train_x) | |||||
self._y = np.copy(train_y) | |||||
self._distance_matrix = edit_distance_matrix(self._x) | |||||
k_matrix = bourgain_embedding_matrix(self._distance_matrix) | |||||
k_matrix[np.diag_indices_from(k_matrix)] += self.alpha | |||||
self._l_matrix = cholesky(k_matrix, lower=True) # Line 2 | |||||
        # cho_solve solves Ax = b and returns x = A^{-1} b
self._alpha_vector = cho_solve( | |||||
(self._l_matrix, True), self._y) # Line 3 | |||||
self._first_fitted = True | |||||
return self | |||||
    # Obtain the mean & std of the predictive distribution.
def predict(self, train_x): | |||||
"""Predict the result. | |||||
Args: | |||||
train_x: A list of NetworkDescriptor. | |||||
Returns: | |||||
y_mean: The predicted mean. | |||||
y_std: The predicted standard deviation. | |||||
""" | |||||
k_trans = np.exp(-np.power(edit_distance_matrix(train_x, self._x), 2)) | |||||
y_mean = k_trans.dot(self._alpha_vector) # Line 4 (y_mean = f_star) | |||||
# compute inverse K_inv of K based on its Cholesky | |||||
# decomposition L and its inverse L_inv | |||||
l_inv = solve_triangular( | |||||
self._l_matrix.T, np.eye( | |||||
self._l_matrix.shape[0])) | |||||
k_inv = l_inv.dot(l_inv.T) | |||||
# Compute variance of predictive distribution | |||||
        y_var = np.ones(len(train_x), dtype=float)
y_var -= np.einsum("ij,ij->i", np.dot(k_trans, k_inv), k_trans) | |||||
# Check if any of the variances is negative because of | |||||
# numerical issues. If yes: set the variance to 0. | |||||
y_var_negative = y_var < 0 | |||||
if np.any(y_var_negative): | |||||
y_var[y_var_negative] = 0.0 | |||||
return y_mean, np.sqrt(y_var) | |||||
def edit_distance_matrix(train_x, train_y=None): | |||||
"""Calculate the edit distance. | |||||
Args: | |||||
train_x: A list of neural architectures. | |||||
train_y: A list of neural architectures. | |||||
Returns: | |||||
An edit-distance matrix. | |||||
""" | |||||
if train_y is None: | |||||
ret = np.zeros((train_x.shape[0], train_x.shape[0])) | |||||
for x_index, x in enumerate(train_x): | |||||
for y_index, y in enumerate(train_x): | |||||
if x_index == y_index: | |||||
ret[x_index][y_index] = 0 | |||||
elif x_index < y_index: | |||||
ret[x_index][y_index] = edit_distance(x, y) | |||||
else: | |||||
ret[x_index][y_index] = ret[y_index][x_index] | |||||
return ret | |||||
ret = np.zeros((train_x.shape[0], train_y.shape[0])) | |||||
for x_index, x in enumerate(train_x): | |||||
for y_index, y in enumerate(train_y): | |||||
ret[x_index][y_index] = edit_distance(x, y) | |||||
return ret | |||||
def vector_distance(a, b): | |||||
"""The Euclidean distance between two vectors.""" | |||||
a = np.array(a) | |||||
b = np.array(b) | |||||
return np.linalg.norm(a - b) | |||||
# A mapping from the edit-distance matrix space to Euclidean space.
def bourgain_embedding_matrix(distance_matrix): | |||||
"""Use Bourgain algorithm to embed the neural architectures based on their edit-distance. | |||||
Args: | |||||
distance_matrix: A matrix of edit-distances. | |||||
Returns: | |||||
A matrix of distances after embedding. | |||||
""" | |||||
distance_matrix = np.array(distance_matrix) | |||||
n = len(distance_matrix) | |||||
if n == 1: | |||||
return distance_matrix | |||||
np.random.seed(123) | |||||
distort_elements = [] | |||||
r = range(n) | |||||
    k = int(math.ceil(math.log(n) / math.log(2) - 1))
    t = int(math.ceil(math.log(n)))
    counter = 0
    for i in range(0, k + 1):
        # use a separate loop variable so the repetition count t stays fixed across scales
        for rep in range(t):
            s = np.random.choice(r, 2 ** i)
            for j in r:
                d = min([distance_matrix[j][idx] for idx in s])
                counter += len(s)
                if i == 0 and rep == 0:
                    distort_elements.append([d])
                else:
                    distort_elements[j].append(d)
    return rbf_kernel(distort_elements, distort_elements)
class BayesianOptimizer: | |||||
""" A Bayesian optimizer for neural architectures. | |||||
Attributes: | |||||
searcher: The Searcher which is calling the Bayesian optimizer. | |||||
t_min: The minimum temperature for simulated annealing. | |||||
metric: An instance of the Metric subclasses. | |||||
gpr: A GaussianProcessRegressor for bayesian optimization. | |||||
beta: The beta in acquisition function. (refer to our paper) | |||||
search_tree: The network morphism search tree. | |||||
""" | |||||
def __init__(self, searcher, t_min, optimizemode, beta=None): | |||||
self.searcher = searcher | |||||
self.t_min = t_min | |||||
self.optimizemode = optimizemode | |||||
self.gpr = IncrementalGaussianProcess() | |||||
self.beta = beta if beta is not None else Constant.BETA | |||||
self.search_tree = SearchTree() | |||||
def fit(self, x_queue, y_queue): | |||||
""" Fit the optimizer with new architectures and performances. | |||||
Args: | |||||
x_queue: A list of NetworkDescriptor. | |||||
y_queue: A list of metric values. | |||||
""" | |||||
self.gpr.fit(x_queue, y_queue) | |||||
# Algorithm 1 | |||||
# optimize acquisition function | |||||
def generate(self, descriptors): | |||||
"""Generate new architecture. | |||||
Args: | |||||
descriptors: All the searched neural architectures. (search history) | |||||
Returns: | |||||
graph: An instance of Graph. A morphed neural network with weights. | |||||
father_id: The father node ID in the search tree. | |||||
""" | |||||
model_ids = self.search_tree.adj_list.keys() | |||||
target_graph = None | |||||
father_id = None | |||||
descriptors = deepcopy(descriptors) | |||||
elem_class = Elem | |||||
if self.optimizemode is OptimizeMode.Maximize: | |||||
elem_class = ReverseElem | |||||
        # 1. Initialize the priority queue.
        # 2. The elements in the priority queue are all previously generated models.
pq = PriorityQueue() | |||||
temp_list = [] | |||||
for model_id in model_ids: | |||||
metric_value = self.searcher.get_metric_value_by_id(model_id) | |||||
temp_list.append((metric_value, model_id)) | |||||
temp_list = sorted(temp_list) | |||||
for metric_value, model_id in temp_list: | |||||
graph = self.searcher.load_model_by_id(model_id) | |||||
graph.clear_operation_history() | |||||
graph.clear_weights() | |||||
            # For models that were already generated, the father_id is their own id.
pq.put(elem_class(metric_value, model_id, graph)) | |||||
t = 1.0 | |||||
t_min = self.t_min | |||||
alpha = 0.9 | |||||
opt_acq = self._get_init_opt_acq_value() | |||||
num_iter = 0 | |||||
# logger.info('initial queue size ', pq.qsize()) | |||||
while not pq.empty() and t > t_min: | |||||
num_iter += 1 | |||||
elem = pq.get() | |||||
# logger.info("elem.metric_value:{}".format(elem.metric_value)) | |||||
# logger.info("opt_acq:{}".format(opt_acq)) | |||||
if self.optimizemode is OptimizeMode.Maximize: | |||||
temp_exp = min((elem.metric_value - opt_acq) / t, 1.0) | |||||
else: | |||||
temp_exp = min((opt_acq - elem.metric_value) / t, 1.0) | |||||
# logger.info("temp_exp this round ", temp_exp) | |||||
ap = math.exp(temp_exp) | |||||
# logger.info("ap this round ", ap) | |||||
if ap >= random.uniform(0, 1): | |||||
# line 9,10 in algorithm 1 | |||||
for temp_graph in transform(elem.graph): | |||||
                    # Skip networks that have already appeared.
if contain(descriptors, temp_graph.extract_descriptor()): | |||||
continue | |||||
                    # Use acq as the evaluation criterion given by the Bayesian model.
temp_acq_value = self.acq(temp_graph) | |||||
                    # This priority queue keeps growing; networks produced by transform enter it as well.
pq.put( | |||||
                        # Remember which father this model was morphed from.
elem_class( | |||||
temp_acq_value, | |||||
elem.father_id, | |||||
temp_graph)) | |||||
# logger.info('temp_acq_value ', temp_acq_value) | |||||
# logger.info('queue size ', pq.qsize()) | |||||
descriptors.append(temp_graph.extract_descriptor()) | |||||
                    # Pick the best one as the father.
if self._accept_new_acq_value(opt_acq, temp_acq_value): | |||||
opt_acq = temp_acq_value | |||||
father_id = elem.father_id | |||||
target_graph = deepcopy(temp_graph) | |||||
t *= alpha | |||||
# logger.info('number of iter in this search {}'.format(num_iter)) | |||||
        # Did not find a non-duplicated architecture.
if father_id is None: | |||||
return None, None | |||||
nm_graph = self.searcher.load_model_by_id(father_id) | |||||
        # Starting from the current father graph, replay the operation_history stored in target_graph
        # step by step to morph the father network into target_graph.
        # clear_operation_history() was called before pushing into pq, so target_graph only stores the
        # operations from the current father network to target_graph, while the operation_history in
        # nm_graph keeps the complete history back to the base model.
for args in target_graph.operation_history: | |||||
getattr(nm_graph, args[0])(*list(args[1:])) | |||||
        # Return the morphed graph together with its father id.
return nm_graph, father_id | |||||
# equation (10) | |||||
def acq(self, graph): | |||||
        ''' Estimate the acquisition value of a generated graph.
''' | |||||
mean, std = self.gpr.predict(np.array([graph.extract_descriptor()])) | |||||
if self.optimizemode is OptimizeMode.Maximize: | |||||
return mean + self.beta * std | |||||
return mean - self.beta * std | |||||
def _get_init_opt_acq_value(self): | |||||
if self.optimizemode is OptimizeMode.Maximize: | |||||
return -np.inf | |||||
return np.inf | |||||
def _accept_new_acq_value(self, opt_acq, temp_acq_value): | |||||
if temp_acq_value > opt_acq and self.optimizemode is OptimizeMode.Maximize: | |||||
return True | |||||
        if temp_acq_value < opt_acq and self.optimizemode is not OptimizeMode.Maximize:
return True | |||||
return False | |||||
def add_child(self, father_id, model_id): | |||||
''' add child to the search tree | |||||
Arguments: | |||||
father_id {int} -- father id | |||||
model_id {int} -- model id | |||||
''' | |||||
self.search_tree.add_child(father_id, model_id) | |||||
@total_ordering | |||||
class Elem: | |||||
"""Elements to be sorted according to metric value.""" | |||||
def __init__(self, metric_value, father_id, graph): | |||||
self.father_id = father_id | |||||
self.graph = graph | |||||
self.metric_value = metric_value | |||||
def __eq__(self, other): | |||||
return self.metric_value == other.metric_value | |||||
def __lt__(self, other): | |||||
return self.metric_value < other.metric_value | |||||
class ReverseElem(Elem): | |||||
"""Elements to be reversely sorted according to metric value.""" | |||||
def __lt__(self, other): | |||||
return self.metric_value > other.metric_value | |||||
def contain(descriptors, target_descriptor): | |||||
"""Check if the target descriptor is in the descriptors.""" | |||||
for descriptor in descriptors: | |||||
if edit_distance(descriptor, target_descriptor) < 1e-5: | |||||
return True | |||||
return False | |||||
class SearchTree: | |||||
"""The network morphism search tree.""" | |||||
def __init__(self): | |||||
self.root = None | |||||
self.adj_list = {} | |||||
def add_child(self, u, v): | |||||
''' add child to search tree itself. | |||||
Arguments: | |||||
u {int} -- father id | |||||
v {int} -- child id | |||||
''' | |||||
if u == -1: | |||||
self.root = v | |||||
self.adj_list[v] = [] | |||||
return | |||||
if v not in self.adj_list[u]: | |||||
self.adj_list[u].append(v) | |||||
if v not in self.adj_list: | |||||
self.adj_list[v] = [] | |||||
def get_dict(self, u=None): | |||||
""" A recursive function to return the content of the tree in a dict.""" | |||||
if u is None: | |||||
return self.get_dict(self.root) | |||||
children = [] | |||||
for v in self.adj_list[u]: | |||||
children.append(self.get_dict(v)) | |||||
ret = {"name": u, "children": children} | |||||
return ret |
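# --- Illustrative sketch (not part of the module API) -------------------------
# How the search tree records parent-child relations between model ids;
# add_child(-1, v) designates v as the root. For example:
#
#     tree = SearchTree()
#     tree.add_child(-1, 0)   # model 0 is the root
#     tree.add_child(0, 1)    # model 1 was morphed from model 0
#     tree.add_child(0, 2)
#     tree.get_dict()
#     # -> {'name': 0, 'children': [{'name': 1, 'children': []},
#     #                             {'name': 2, 'children': []}]}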
@@ -0,0 +1,919 @@ | |||||
import json | |||||
from collections.abc import Iterable | |||||
from copy import deepcopy, copy | |||||
from queue import Queue | |||||
import numpy as np | |||||
import torch | |||||
from .layer_transformer import ( | |||||
add_noise, | |||||
wider_bn, | |||||
wider_next_conv, | |||||
wider_next_dense, | |||||
wider_pre_conv, | |||||
wider_pre_dense, | |||||
init_dense_weight, | |||||
init_conv_weight, | |||||
init_bn_weight, | |||||
) | |||||
from .layers import ( | |||||
StubAdd, | |||||
StubConcatenate, | |||||
StubReLU, | |||||
get_batch_norm_class, | |||||
get_conv_class, | |||||
is_layer, | |||||
layer_width, | |||||
set_stub_weight_to_torch, | |||||
set_torch_weight_to_stub, | |||||
layer_description_extractor, | |||||
layer_description_builder, | |||||
) | |||||
from utils import Constant | |||||
class NetworkDescriptor: | |||||
"""A class describing the neural architecture for neural network kernel. | |||||
    It only records the widths of the convolutional and dense layers, and the skip-connection types and positions.
""" | |||||
CONCAT_CONNECT = "concat" | |||||
ADD_CONNECT = "add" | |||||
def __init__(self): | |||||
self.skip_connections = [] | |||||
self.layers = [] | |||||
@property | |||||
def n_layers(self): | |||||
return len(self.layers) | |||||
def add_skip_connection(self, u, v, connection_type): | |||||
""" Add a skip-connection to the descriptor. | |||||
Args: | |||||
u: Number of convolutional layers before the starting point. | |||||
v: Number of convolutional layers before the ending point. | |||||
connection_type: Must be either CONCAT_CONNECT or ADD_CONNECT. | |||||
""" | |||||
if connection_type not in [self.CONCAT_CONNECT, self.ADD_CONNECT]: | |||||
raise ValueError( | |||||
"connection_type should be NetworkDescriptor.CONCAT_CONNECT " | |||||
"or NetworkDescriptor.ADD_CONNECT." | |||||
) | |||||
self.skip_connections.append((u, v, connection_type)) | |||||
def to_json(self): | |||||
''' NetworkDescriptor to json representation | |||||
''' | |||||
skip_list = [] | |||||
for u, v, connection_type in self.skip_connections: | |||||
skip_list.append({"from": u, "to": v, "type": connection_type}) | |||||
return {"node_list": self.layers, "skip_list": skip_list} | |||||
def add_layer(self, layer): | |||||
''' add one layer | |||||
''' | |||||
self.layers.append(layer) | |||||
class Node: | |||||
"""A class for intermediate output tensor (node) in the Graph. | |||||
Attributes: | |||||
shape: A tuple describing the shape of the tensor. | |||||
""" | |||||
def __init__(self, shape): | |||||
self.shape = shape | |||||
class Graph: | |||||
"""A class representing the neural architecture graph of a model. | |||||
Graph extracts the neural architecture graph from a model. | |||||
    Each node in the graph is an intermediate tensor between layers.
    Each layer is an edge in the graph.
    Notably, multiple edges may refer to the same layer.
    (e.g., an Add layer adds two tensors into one tensor, so it is related to two edges.)
Attributes: | |||||
weighted: A boolean of whether the weights and biases in the neural network | |||||
should be included in the graph. | |||||
input_shape: A tuple of integers, which does not include the batch axis. | |||||
node_list: A list of integers. The indices of the list are the identifiers. | |||||
layer_list: A list of stub layers. The indices of the list are the identifiers. | |||||
node_to_id: A dict instance mapping from node integers to their identifiers. | |||||
layer_to_id: A dict instance mapping from stub layers to their identifiers. | |||||
layer_id_to_input_node_ids: A dict instance mapping from layer identifiers | |||||
to their input nodes identifiers. | |||||
layer_id_to_output_node_ids: A dict instance mapping from layer identifiers | |||||
to their output nodes identifiers. | |||||
adj_list: A two dimensional list. The adjacency list of the graph. The first dimension is | |||||
identified by tensor identifiers. In each edge list, the elements are two-element tuples | |||||
of (tensor identifier, layer identifier). | |||||
        reverse_adj_list: The reverse adjacency list, in the same format as adj_list.
        operation_history: A list saving all the network morphism operations.
        vis: A dictionary of temporary storage for whether a local operation has been done
            during the network morphism.
""" | |||||
def __init__(self, input_shape, weighted=True): | |||||
"""Initializer for Graph. | |||||
""" | |||||
self.input_shape = input_shape | |||||
self.weighted = weighted | |||||
self.node_list = [] | |||||
self.layer_list = [] | |||||
        # node ids start from 0
self.node_to_id = {} | |||||
self.layer_to_id = {} | |||||
self.layer_id_to_input_node_ids = {} | |||||
self.layer_id_to_output_node_ids = {} | |||||
self.adj_list = {} | |||||
self.reverse_adj_list = {} | |||||
self.operation_history = [] | |||||
self.n_dim = len(input_shape) - 1 | |||||
self.conv = get_conv_class(self.n_dim) | |||||
self.batch_norm = get_batch_norm_class(self.n_dim) | |||||
self.vis = None | |||||
self._add_node(Node(input_shape)) | |||||
def add_layer(self, layer, input_node_id): | |||||
"""Add a layer to the Graph. | |||||
Args: | |||||
layer: An instance of the subclasses of StubLayer in layers.py. | |||||
input_node_id: An integer. The ID of the input node of the layer. | |||||
Returns: | |||||
output_node_id: An integer. The ID of the output node of the layer. | |||||
""" | |||||
if isinstance(input_node_id, Iterable): | |||||
layer.input = list(map(lambda x: self.node_list[x], input_node_id)) | |||||
output_node_id = self._add_node(Node(layer.output_shape)) | |||||
for node_id in input_node_id: | |||||
self._add_edge(layer, node_id, output_node_id) | |||||
else: | |||||
layer.input = self.node_list[input_node_id] | |||||
output_node_id = self._add_node(Node(layer.output_shape)) | |||||
self._add_edge(layer, input_node_id, output_node_id) | |||||
layer.output = self.node_list[output_node_id] | |||||
return output_node_id | |||||
def clear_operation_history(self): | |||||
self.operation_history = [] | |||||
@property | |||||
def n_nodes(self): | |||||
"""Return the number of nodes in the model.""" | |||||
return len(self.node_list) | |||||
@property | |||||
def n_layers(self): | |||||
"""Return the number of layers in the model.""" | |||||
return len(self.layer_list) | |||||
def _add_node(self, node): | |||||
"""Add a new node to node_list and give the node an ID. | |||||
Args: | |||||
node: An instance of Node. | |||||
Returns: | |||||
node_id: An integer. | |||||
""" | |||||
node_id = len(self.node_list) | |||||
self.node_to_id[node] = node_id | |||||
self.node_list.append(node) | |||||
self.adj_list[node_id] = [] | |||||
self.reverse_adj_list[node_id] = [] | |||||
return node_id | |||||
def _add_edge(self, layer, input_id, output_id): | |||||
"""Add a new layer to the graph. The nodes should be created in advance.""" | |||||
if layer in self.layer_to_id: | |||||
layer_id = self.layer_to_id[layer] | |||||
if input_id not in self.layer_id_to_input_node_ids[layer_id]: | |||||
self.layer_id_to_input_node_ids[layer_id].append(input_id) | |||||
if output_id not in self.layer_id_to_output_node_ids[layer_id]: | |||||
self.layer_id_to_output_node_ids[layer_id].append(output_id) | |||||
else: | |||||
layer_id = len(self.layer_list) | |||||
self.layer_list.append(layer) | |||||
self.layer_to_id[layer] = layer_id | |||||
self.layer_id_to_input_node_ids[layer_id] = [input_id] | |||||
self.layer_id_to_output_node_ids[layer_id] = [output_id] | |||||
self.adj_list[input_id].append((output_id, layer_id)) | |||||
self.reverse_adj_list[output_id].append((input_id, layer_id)) | |||||
def _redirect_edge(self, u_id, v_id, new_v_id): | |||||
"""Redirect the layer to a new node. | |||||
Change the edge originally from `u_id` to `v_id` into an edge from `u_id` to `new_v_id` | |||||
while keeping all other property of the edge the same. | |||||
""" | |||||
layer_id = None | |||||
for index, edge_tuple in enumerate(self.adj_list[u_id]): | |||||
if edge_tuple[0] == v_id: | |||||
layer_id = edge_tuple[1] | |||||
self.adj_list[u_id][index] = (new_v_id, layer_id) | |||||
self.layer_list[layer_id].output = self.node_list[new_v_id] | |||||
break | |||||
for index, edge_tuple in enumerate(self.reverse_adj_list[v_id]): | |||||
if edge_tuple[0] == u_id: | |||||
layer_id = edge_tuple[1] | |||||
self.reverse_adj_list[v_id].remove(edge_tuple) | |||||
break | |||||
self.reverse_adj_list[new_v_id].append((u_id, layer_id)) | |||||
for index, value in enumerate( | |||||
self.layer_id_to_output_node_ids[layer_id]): | |||||
if value == v_id: | |||||
self.layer_id_to_output_node_ids[layer_id][index] = new_v_id | |||||
break | |||||
def _replace_layer(self, layer_id, new_layer): | |||||
"""Replace the layer with a new layer.""" | |||||
old_layer = self.layer_list[layer_id] | |||||
new_layer.input = old_layer.input | |||||
new_layer.output = old_layer.output | |||||
new_layer.output.shape = new_layer.output_shape | |||||
self.layer_list[layer_id] = new_layer | |||||
self.layer_to_id[new_layer] = layer_id | |||||
self.layer_to_id.pop(old_layer) | |||||
@property | |||||
def topological_order(self): | |||||
"""Return the topological order of the node IDs from the input node to the output node.""" | |||||
q = Queue() | |||||
in_degree = {} | |||||
for i in range(self.n_nodes): | |||||
in_degree[i] = 0 | |||||
for u in range(self.n_nodes): | |||||
for v, _ in self.adj_list[u]: | |||||
in_degree[v] += 1 | |||||
for i in range(self.n_nodes): | |||||
if in_degree[i] == 0: | |||||
q.put(i) | |||||
order_list = [] | |||||
while not q.empty(): | |||||
u = q.get() | |||||
order_list.append(u) | |||||
for v, _ in self.adj_list[u]: | |||||
in_degree[v] -= 1 | |||||
if in_degree[v] == 0: | |||||
q.put(v) | |||||
return order_list | |||||
def _get_pooling_layers(self, start_node_id, end_node_id): | |||||
""" | |||||
Given two node IDs, return all the pooling layers between them. | |||||
        A conv layer with stride > 1 is also considered a pooling layer.
""" | |||||
layer_list = [] | |||||
node_list = [start_node_id] | |||||
assert self._depth_first_search(end_node_id, layer_list, node_list) | |||||
ret = [] | |||||
for layer_id in layer_list: | |||||
layer = self.layer_list[layer_id] | |||||
if is_layer(layer, "Pooling"): | |||||
ret.append(layer) | |||||
elif is_layer(layer, "Conv") and layer.stride != 1: | |||||
ret.append(layer) | |||||
return ret | |||||
def _depth_first_search(self, target_id, layer_id_list, node_list): | |||||
"""Search for all the layers and nodes down the path. | |||||
A recursive function to search all the layers and nodes between the node in the node_list | |||||
and the node with target_id.""" | |||||
assert len(node_list) <= self.n_nodes | |||||
u = node_list[-1] | |||||
if u == target_id: | |||||
return True | |||||
for v, layer_id in self.adj_list[u]: | |||||
layer_id_list.append(layer_id) | |||||
node_list.append(v) | |||||
if self._depth_first_search(target_id, layer_id_list, node_list): | |||||
return True | |||||
layer_id_list.pop() | |||||
node_list.pop() | |||||
return False | |||||
def _search(self, u, start_dim, total_dim, n_add): | |||||
"""Search the graph for all the layers to be widened caused by an operation. | |||||
        It is a recursive function with a duplication check to avoid infinite recursion.
        It searches from a starting node u until the corresponding layers have been widened.
Args: | |||||
u: The starting node ID. | |||||
start_dim: The position to insert the additional dimensions. | |||||
total_dim: The total number of dimensions the layer has before widening. | |||||
n_add: The number of dimensions to add. | |||||
""" | |||||
if (u, start_dim, total_dim, n_add) in self.vis: | |||||
return | |||||
self.vis[(u, start_dim, total_dim, n_add)] = True | |||||
for v, layer_id in self.adj_list[u]: | |||||
layer = self.layer_list[layer_id] | |||||
if is_layer(layer, "Conv"): | |||||
new_layer = wider_next_conv( | |||||
layer, start_dim, total_dim, n_add, self.weighted | |||||
) | |||||
self._replace_layer(layer_id, new_layer) | |||||
elif is_layer(layer, "Dense"): | |||||
new_layer = wider_next_dense( | |||||
layer, start_dim, total_dim, n_add, self.weighted | |||||
) | |||||
self._replace_layer(layer_id, new_layer) | |||||
elif is_layer(layer, "BatchNormalization"): | |||||
new_layer = wider_bn( | |||||
layer, start_dim, total_dim, n_add, self.weighted) | |||||
self._replace_layer(layer_id, new_layer) | |||||
self._search(v, start_dim, total_dim, n_add) | |||||
elif is_layer(layer, "Concatenate"): | |||||
if self.layer_id_to_input_node_ids[layer_id][1] == u: | |||||
# u is on the right of the concat | |||||
# next_start_dim += next_total_dim - total_dim | |||||
left_dim = self._upper_layer_width( | |||||
self.layer_id_to_input_node_ids[layer_id][0] | |||||
) | |||||
next_start_dim = start_dim + left_dim | |||||
next_total_dim = total_dim + left_dim | |||||
else: | |||||
next_start_dim = start_dim | |||||
next_total_dim = total_dim + self._upper_layer_width( | |||||
self.layer_id_to_input_node_ids[layer_id][1] | |||||
) | |||||
self._search(v, next_start_dim, next_total_dim, n_add) | |||||
else: | |||||
self._search(v, start_dim, total_dim, n_add) | |||||
for v, layer_id in self.reverse_adj_list[u]: | |||||
layer = self.layer_list[layer_id] | |||||
if is_layer(layer, "Conv"): | |||||
new_layer = wider_pre_conv(layer, n_add, self.weighted) | |||||
self._replace_layer(layer_id, new_layer) | |||||
elif is_layer(layer, "Dense"): | |||||
new_layer = wider_pre_dense(layer, n_add, self.weighted) | |||||
self._replace_layer(layer_id, new_layer) | |||||
elif is_layer(layer, "Concatenate"): | |||||
continue | |||||
else: | |||||
self._search(v, start_dim, total_dim, n_add) | |||||
def _upper_layer_width(self, u): | |||||
for v, layer_id in self.reverse_adj_list[u]: | |||||
layer = self.layer_list[layer_id] | |||||
if is_layer(layer, "Conv") or is_layer(layer, "Dense"): | |||||
return layer_width(layer) | |||||
elif is_layer(layer, "Concatenate"): | |||||
a = self.layer_id_to_input_node_ids[layer_id][0] | |||||
b = self.layer_id_to_input_node_ids[layer_id][1] | |||||
return self._upper_layer_width(a) + self._upper_layer_width(b) | |||||
else: | |||||
return self._upper_layer_width(v) | |||||
return self.node_list[0].shape[-1] | |||||
def to_deeper_model(self, target_id, new_layer): | |||||
"""Insert a relu-conv-bn block after the target block. | |||||
Args: | |||||
target_id: A convolutional layer ID. The new block should be inserted after the block. | |||||
new_layer: An instance of StubLayer subclasses. | |||||
""" | |||||
self.operation_history.append( | |||||
("to_deeper_model", target_id, new_layer)) | |||||
input_id = self.layer_id_to_input_node_ids[target_id][0] | |||||
output_id = self.layer_id_to_output_node_ids[target_id][0] | |||||
if self.weighted: | |||||
if is_layer(new_layer, "Dense"): | |||||
init_dense_weight(new_layer) | |||||
elif is_layer(new_layer, "Conv"): | |||||
init_conv_weight(new_layer) | |||||
elif is_layer(new_layer, "BatchNormalization"): | |||||
init_bn_weight(new_layer) | |||||
self._insert_new_layers([new_layer], input_id, output_id) | |||||
def to_wider_model(self, pre_layer_id, n_add): | |||||
"""Widen the last dimension of the output of the pre_layer. | |||||
Args: | |||||
pre_layer_id: The ID of a convolutional layer or dense layer. | |||||
n_add: The number of dimensions to add. | |||||
""" | |||||
self.operation_history.append(("to_wider_model", pre_layer_id, n_add)) | |||||
pre_layer = self.layer_list[pre_layer_id] | |||||
output_id = self.layer_id_to_output_node_ids[pre_layer_id][0] | |||||
dim = layer_width(pre_layer) | |||||
self.vis = {} | |||||
self._search(output_id, dim, dim, n_add) | |||||
# Update the tensor shapes. | |||||
for u in self.topological_order: | |||||
for v, layer_id in self.adj_list[u]: | |||||
self.node_list[v].shape = self.layer_list[layer_id].output_shape | |||||
def _insert_new_layers(self, new_layers, start_node_id, end_node_id): | |||||
"""Insert the new_layers after the node with start_node_id.""" | |||||
new_node_id = self._add_node(deepcopy(self.node_list[end_node_id])) | |||||
temp_output_id = new_node_id | |||||
for layer in new_layers[:-1]: | |||||
temp_output_id = self.add_layer(layer, temp_output_id) | |||||
self._add_edge(new_layers[-1], temp_output_id, end_node_id) | |||||
new_layers[-1].input = self.node_list[temp_output_id] | |||||
new_layers[-1].output = self.node_list[end_node_id] | |||||
self._redirect_edge(start_node_id, end_node_id, new_node_id) | |||||
def _block_end_node(self, layer_id, block_size): | |||||
ret = self.layer_id_to_output_node_ids[layer_id][0] | |||||
for _ in range(block_size - 2): | |||||
ret = self.adj_list[ret][0][0] | |||||
return ret | |||||
def _dense_block_end_node(self, layer_id): | |||||
return self.layer_id_to_input_node_ids[layer_id][0] | |||||
def _conv_block_end_node(self, layer_id): | |||||
"""Get the input node ID of the last layer in the block by layer ID. | |||||
Return the input node ID of the last layer in the convolutional block. | |||||
Args: | |||||
layer_id: the convolutional layer ID. | |||||
""" | |||||
return self._block_end_node(layer_id, Constant.CONV_BLOCK_DISTANCE) | |||||
def to_add_skip_model(self, start_id, end_id): | |||||
"""Add a weighted add skip-connection from after start node to end node. | |||||
Args: | |||||
start_id: The convolutional layer ID, after which to start the skip-connection. | |||||
end_id: The convolutional layer ID, after which to end the skip-connection. | |||||
""" | |||||
self.operation_history.append(("to_add_skip_model", start_id, end_id)) | |||||
filters_end = self.layer_list[end_id].output.shape[-1] | |||||
filters_start = self.layer_list[start_id].output.shape[-1] | |||||
start_node_id = self.layer_id_to_output_node_ids[start_id][0] | |||||
pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0] | |||||
end_node_id = self.layer_id_to_output_node_ids[end_id][0] | |||||
skip_output_id = self._insert_pooling_layer_chain( | |||||
start_node_id, end_node_id) | |||||
# Add the conv layer in order to align the number of channels with end layer id | |||||
new_conv_layer = get_conv_class( | |||||
self.n_dim)( | |||||
filters_start, | |||||
filters_end, | |||||
1) | |||||
skip_output_id = self.add_layer(new_conv_layer, skip_output_id) | |||||
# Add the add layer. | |||||
add_input_node_id = self._add_node( | |||||
deepcopy(self.node_list[end_node_id])) | |||||
add_layer = StubAdd() | |||||
self._redirect_edge(pre_end_node_id, end_node_id, add_input_node_id) | |||||
self._add_edge(add_layer, add_input_node_id, end_node_id) | |||||
self._add_edge(add_layer, skip_output_id, end_node_id) | |||||
add_layer.input = [ | |||||
self.node_list[add_input_node_id], | |||||
self.node_list[skip_output_id], | |||||
] | |||||
add_layer.output = self.node_list[end_node_id] | |||||
self.node_list[end_node_id].shape = add_layer.output_shape | |||||
# Set weights to the additional conv layer. | |||||
if self.weighted: | |||||
filter_shape = (1,) * self.n_dim | |||||
weights = np.zeros((filters_end, filters_start) + filter_shape) | |||||
bias = np.zeros(filters_end) | |||||
new_conv_layer.set_weights( | |||||
(add_noise(weights, np.array([0, 1])), add_noise( | |||||
bias, np.array([0, 1]))) | |||||
) | |||||
def to_concat_skip_model(self, start_id, end_id): | |||||
"""Add a weighted add concatenate connection from after start node to end node. | |||||
Args: | |||||
start_id: The convolutional layer ID, after which to start the skip-connection. | |||||
end_id: The convolutional layer ID, after which to end the skip-connection. | |||||
""" | |||||
self.operation_history.append( | |||||
("to_concat_skip_model", start_id, end_id)) | |||||
filters_end = self.layer_list[end_id].output.shape[-1] | |||||
filters_start = self.layer_list[start_id].output.shape[-1] | |||||
start_node_id = self.layer_id_to_output_node_ids[start_id][0] | |||||
pre_end_node_id = self.layer_id_to_input_node_ids[end_id][0] | |||||
end_node_id = self.layer_id_to_output_node_ids[end_id][0] | |||||
skip_output_id = self._insert_pooling_layer_chain( | |||||
start_node_id, end_node_id) | |||||
concat_input_node_id = self._add_node( | |||||
deepcopy(self.node_list[end_node_id])) | |||||
self._redirect_edge(pre_end_node_id, end_node_id, concat_input_node_id) | |||||
concat_layer = StubConcatenate() | |||||
concat_layer.input = [ | |||||
self.node_list[concat_input_node_id], | |||||
self.node_list[skip_output_id], | |||||
] | |||||
concat_output_node_id = self._add_node(Node(concat_layer.output_shape)) | |||||
self._add_edge( | |||||
concat_layer, | |||||
concat_input_node_id, | |||||
concat_output_node_id) | |||||
self._add_edge(concat_layer, skip_output_id, concat_output_node_id) | |||||
concat_layer.output = self.node_list[concat_output_node_id] | |||||
self.node_list[concat_output_node_id].shape = concat_layer.output_shape | |||||
# Concatenation increases the channel count; add a 1x1 conv layer to
# project back to the original number of channels.
new_conv_layer = get_conv_class(self.n_dim)( | |||||
filters_start + filters_end, filters_end, 1 | |||||
) | |||||
self._add_edge(new_conv_layer, concat_output_node_id, end_node_id) | |||||
new_conv_layer.input = self.node_list[concat_output_node_id] | |||||
new_conv_layer.output = self.node_list[end_node_id] | |||||
self.node_list[end_node_id].shape = new_conv_layer.output_shape | |||||
if self.weighted: | |||||
filter_shape = (1,) * self.n_dim | |||||
weights = np.zeros((filters_end, filters_end) + filter_shape) | |||||
for i in range(filters_end): | |||||
filter_weight = np.zeros((filters_end,) + filter_shape) | |||||
center_index = (i,) + (0,) * self.n_dim | |||||
filter_weight[center_index] = 1 | |||||
weights[i, ...] = filter_weight | |||||
weights = np.concatenate( | |||||
(weights, np.zeros((filters_end, filters_start) + filter_shape)), axis=1 | |||||
) | |||||
bias = np.zeros(filters_end) | |||||
new_conv_layer.set_weights( | |||||
(add_noise(weights, np.array([0, 1])), add_noise( | |||||
bias, np.array([0, 1]))) | |||||
) | |||||
def _insert_pooling_layer_chain(self, start_node_id, end_node_id): | |||||
""" | |||||
insert pooling layer | |||||
""" | |||||
skip_output_id = start_node_id | |||||
# Walk over all pooling layers between start_node_id and end_node_id
# (conv layers with stride > 1 count as pooling here).
for layer in self._get_pooling_layers(start_node_id, end_node_id):
if is_layer(layer, "Conv"):
# Conv layers need freshly initialized weights; keep the channel
# count of start_node_id so the skip path stays shape-compatible.
filters = self.node_list[start_node_id].shape[-1]
new_layer = get_conv_class(self.n_dim)(
filters, filters, 1, layer.stride)
if self.weighted:
init_conv_weight(new_layer)
else:
new_layer = deepcopy(layer)
skip_output_id = self.add_layer(new_layer, skip_output_id)
skip_output_id = self.add_layer(StubReLU(), skip_output_id) | |||||
return skip_output_id | |||||
def extract_descriptor(self): | |||||
"""Extract the the description of the Graph as an instance of NetworkDescriptor.""" | |||||
main_chain = self.get_main_chain() | |||||
index_in_main_chain = {} | |||||
for index, u in enumerate(main_chain): | |||||
index_in_main_chain[u] = index | |||||
ret = NetworkDescriptor() | |||||
for u in main_chain: | |||||
for v, layer_id in self.adj_list[u]: | |||||
if v not in index_in_main_chain: | |||||
continue | |||||
layer = self.layer_list[layer_id] | |||||
copied_layer = copy(layer) | |||||
copied_layer.weights = None | |||||
ret.add_layer(deepcopy(copied_layer)) | |||||
for u in index_in_main_chain: | |||||
for v, layer_id in self.adj_list[u]: | |||||
if v not in index_in_main_chain: | |||||
temp_u = u | |||||
temp_v = v | |||||
temp_layer_id = layer_id | |||||
skip_type = None | |||||
while not ( | |||||
temp_v in index_in_main_chain and temp_u in index_in_main_chain): | |||||
if is_layer( | |||||
self.layer_list[temp_layer_id], "Concatenate"): | |||||
skip_type = NetworkDescriptor.CONCAT_CONNECT | |||||
if is_layer(self.layer_list[temp_layer_id], "Add"): | |||||
skip_type = NetworkDescriptor.ADD_CONNECT | |||||
temp_u = temp_v | |||||
temp_v, temp_layer_id = self.adj_list[temp_v][0] | |||||
ret.add_skip_connection( | |||||
index_in_main_chain[u], index_in_main_chain[temp_u], skip_type | |||||
) | |||||
elif index_in_main_chain[v] - index_in_main_chain[u] != 1: | |||||
skip_type = None | |||||
if is_layer(self.layer_list[layer_id], "Concatenate"): | |||||
skip_type = NetworkDescriptor.CONCAT_CONNECT | |||||
if is_layer(self.layer_list[layer_id], "Add"): | |||||
skip_type = NetworkDescriptor.ADD_CONNECT | |||||
ret.add_skip_connection( | |||||
index_in_main_chain[u], index_in_main_chain[v], skip_type | |||||
) | |||||
return ret | |||||
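# Added note (illustrative): the resulting descriptor encodes the graph as
# its main-chain layers plus a set of skip-connection triples
# (start index, end index, ADD_CONNECT or CONCAT_CONNECT) over main-chain
# positions, which is what the Bayesian optimizer compares between graphs.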
def clear_weights(self): | |||||
'''Clear the weights of every layer in the graph.
'''
self.weighted = False | |||||
for layer in self.layer_list: | |||||
layer.weights = None | |||||
def produce_torch_model(self): | |||||
"""Build a new Torch model based on the current graph.""" | |||||
return TorchModel(self) | |||||
def produce_json_model(self): | |||||
"""Build a new Json model based on the current graph.""" | |||||
return JSONModel(self).data | |||||
@classmethod | |||||
def parsing_json_model(cls, json_model): | |||||
'''Build a graph from a JSON model.
'''
return json_to_graph(json_model) | |||||
def _layer_ids_in_order(self, layer_ids): | |||||
node_id_to_order_index = {} | |||||
for index, node_id in enumerate(self.topological_order): | |||||
node_id_to_order_index[node_id] = index | |||||
return sorted( | |||||
layer_ids, | |||||
key=lambda layer_id: node_id_to_order_index[ | |||||
self.layer_id_to_output_node_ids[layer_id][0] | |||||
], | |||||
) | |||||
def _layer_ids_by_type(self, type_str): | |||||
return list( | |||||
filter( | |||||
lambda layer_id: is_layer(self.layer_list[layer_id], type_str), | |||||
range(self.n_layers), | |||||
) | |||||
) | |||||
def get_main_chain_layers(self): | |||||
"""Return a list of layer IDs in the main chain.""" | |||||
main_chain = self.get_main_chain() | |||||
ret = [] | |||||
for u in main_chain: | |||||
for v, layer_id in self.adj_list[u]: | |||||
if v in main_chain and u in main_chain: | |||||
ret.append(layer_id) | |||||
return ret | |||||
def _conv_layer_ids_in_order(self): | |||||
return list( | |||||
filter( | |||||
lambda layer_id: is_layer(self.layer_list[layer_id], "Conv"), | |||||
self.get_main_chain_layers(), | |||||
) | |||||
) | |||||
def _dense_layer_ids_in_order(self): | |||||
return self._layer_ids_in_order(self._layer_ids_by_type("Dense")) | |||||
def deep_layer_ids(self): | |||||
ret = [] | |||||
for layer_id in self.get_main_chain_layers(): | |||||
layer = self.layer_list[layer_id] | |||||
# Do not insert new layers after global average pooling.
if is_layer(layer, "GlobalAveragePooling"): | |||||
break | |||||
if is_layer(layer, "Add") or is_layer(layer, "Concatenate"): | |||||
continue | |||||
ret.append(layer_id) | |||||
return ret | |||||
def wide_layer_ids(self): | |||||
return ( | |||||
self._conv_layer_ids_in_order( | |||||
)[:-1] + self._dense_layer_ids_in_order()[:-1] | |||||
) | |||||
def skip_connection_layer_ids(self): | |||||
return self.deep_layer_ids()[:-1] | |||||
def size(self): | |||||
return sum(layer.size() for layer in self.layer_list)
def get_main_chain(self): | |||||
"""Returns the main chain node ID list.""" | |||||
pre_node = {} | |||||
distance = {} | |||||
# Initialize each node's distance to 0 and its predecessor to itself.
for i in range(self.n_nodes): | |||||
distance[i] = 0 | |||||
pre_node[i] = i | |||||
# Relax every edge repeatedly (Bellman-Ford style): distance[v] becomes the
# length of the longest path reaching v, and pre_node[v] its predecessor.
for i in range(self.n_nodes - 1): | |||||
for u in range(self.n_nodes): | |||||
for v, _ in self.adj_list[u]: | |||||
if distance[u] + 1 > distance[v]: | |||||
distance[v] = distance[u] + 1 | |||||
pre_node[v] = u | |||||
# temp_id records the node with the maximum distance.
temp_id = 0 | |||||
for i in range(self.n_nodes): | |||||
if distance[i] > distance[temp_id]: | |||||
temp_id = i | |||||
# Walk back from the farthest node through its predecessors to recover the main chain.
ret = [] | |||||
for i in range(self.n_nodes + 5): | |||||
ret.append(temp_id) | |||||
if pre_node[temp_id] == temp_id: | |||||
break | |||||
temp_id = pre_node[temp_id] | |||||
assert temp_id == pre_node[temp_id] | |||||
ret.reverse() | |||||
return ret | |||||
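# Added worked example: for edges 0 -> 1 -> 2 -> 4 plus a shortcut
# 0 -> 3 -> 4, relaxation yields distance = {0: 0, 1: 1, 2: 2, 3: 1, 4: 3}
# with pre_node[4] = 2, so the backward walk from node 4 recovers the main
# chain [0, 1, 2, 4] and skips the shortcut branch.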
class TorchModel(torch.nn.Module): | |||||
"""A neural network class using pytorch constructed from an instance of Graph.""" | |||||
def __init__(self, graph): | |||||
super(TorchModel, self).__init__() | |||||
self.graph = graph | |||||
self.layers = torch.nn.ModuleList() | |||||
for layer in graph.layer_list: | |||||
self.layers.append(layer.to_real_layer()) | |||||
if graph.weighted: | |||||
for index, layer in enumerate(self.layers): | |||||
set_stub_weight_to_torch(self.graph.layer_list[index], layer) | |||||
for index, layer in enumerate(self.layers): | |||||
self.add_module(str(index), layer) | |||||
def forward(self, input_tensor): | |||||
topo_node_list = self.graph.topological_order | |||||
output_id = topo_node_list[-1] | |||||
input_id = topo_node_list[0] | |||||
node_list = deepcopy(self.graph.node_list) | |||||
node_list[input_id] = input_tensor | |||||
for v in topo_node_list: | |||||
for u, layer_id in self.graph.reverse_adj_list[v]: | |||||
layer = self.graph.layer_list[layer_id] | |||||
torch_layer = self.layers[layer_id] | |||||
if isinstance(layer, (StubAdd, StubConcatenate)): | |||||
edge_input_tensor = list( | |||||
map( | |||||
lambda x: node_list[x], | |||||
self.graph.layer_id_to_input_node_ids[layer_id], | |||||
) | |||||
) | |||||
else: | |||||
edge_input_tensor = node_list[u] | |||||
temp_tensor = torch_layer(edge_input_tensor) | |||||
node_list[v] = temp_tensor | |||||
return node_list[output_id] | |||||
def set_weight_to_graph(self): | |||||
self.graph.weighted = True | |||||
for index, layer in enumerate(self.layers): | |||||
set_torch_weight_to_stub(layer, self.graph.layer_list[index]) | |||||
class JSONModel: | |||||
def __init__(self, graph): | |||||
data = dict() | |||||
node_list = list() | |||||
layer_list = list() | |||||
operation_history = list() | |||||
data["input_shape"] = graph.input_shape | |||||
vis = graph.vis | |||||
data["vis"] = list(vis.keys()) if vis is not None else None | |||||
data["weighted"] = graph.weighted | |||||
for item in graph.operation_history: | |||||
if item[0] == "to_deeper_model": | |||||
operation_history.append( | |||||
[ | |||||
item[0], | |||||
item[1], | |||||
layer_description_extractor(item[2], graph.node_to_id), | |||||
] | |||||
) | |||||
else: | |||||
operation_history.append(item) | |||||
data["operation_history"] = operation_history | |||||
data["layer_id_to_input_node_ids"] = graph.layer_id_to_input_node_ids | |||||
data["layer_id_to_output_node_ids"] = graph.layer_id_to_output_node_ids | |||||
data["adj_list"] = graph.adj_list | |||||
data["reverse_adj_list"] = graph.reverse_adj_list | |||||
for node in graph.node_list: | |||||
node_id = graph.node_to_id[node] | |||||
node_information = node.shape | |||||
node_list.append((node_id, node_information)) | |||||
for layer_id, item in enumerate(graph.layer_list): | |||||
layer = graph.layer_list[layer_id] | |||||
layer_information = layer_description_extractor( | |||||
layer, graph.node_to_id) | |||||
layer_list.append((layer_id, layer_information)) | |||||
data["node_list"] = node_list | |||||
data["layer_list"] = layer_list | |||||
self.data = data | |||||
def graph_to_json(graph, json_model_path): | |||||
json_out = graph.produce_json_model() | |||||
with open(json_model_path, "w") as fout: | |||||
json.dump(json_out, fout) | |||||
json_out = json.dumps(json_out) | |||||
return json_out | |||||
def json_to_graph(json_model: str): | |||||
json_model = json.loads(json_model) | |||||
# restore graph data from json data | |||||
input_shape = tuple(json_model["input_shape"]) | |||||
node_list = list() | |||||
node_to_id = dict() | |||||
id_to_node = dict() | |||||
layer_list = list() | |||||
layer_to_id = dict() | |||||
operation_history = list() | |||||
graph = Graph(input_shape, False) | |||||
graph.input_shape = input_shape | |||||
vis = json_model["vis"] | |||||
graph.vis = { | |||||
tuple(item): True for item in vis} if vis is not None else None | |||||
graph.weighted = json_model["weighted"] | |||||
layer_id_to_input_node_ids = json_model["layer_id_to_input_node_ids"] | |||||
graph.layer_id_to_input_node_ids = { | |||||
int(k): v for k, v in layer_id_to_input_node_ids.items() | |||||
} | |||||
layer_id_to_output_node_ids = json_model["layer_id_to_output_node_ids"] | |||||
graph.layer_id_to_output_node_ids = { | |||||
int(k): v for k, v in layer_id_to_output_node_ids.items() | |||||
} | |||||
adj_list = {} | |||||
for k, v in json_model["adj_list"].items(): | |||||
adj_list[int(k)] = [tuple(i) for i in v] | |||||
graph.adj_list = adj_list | |||||
reverse_adj_list = {} | |||||
for k, v in json_model["reverse_adj_list"].items(): | |||||
reverse_adj_list[int(k)] = [tuple(i) for i in v] | |||||
graph.reverse_adj_list = reverse_adj_list | |||||
for item in json_model["node_list"]: | |||||
new_node = Node(tuple(item[1])) | |||||
node_id = item[0] | |||||
node_list.append(new_node) | |||||
node_to_id[new_node] = node_id | |||||
id_to_node[node_id] = new_node | |||||
for item in json_model["operation_history"]: | |||||
if item[0] == "to_deeper_model": | |||||
operation_history.append( | |||||
(item[0], item[1], layer_description_builder(item[2], id_to_node)) | |||||
) | |||||
else: | |||||
operation_history.append(item) | |||||
graph.operation_history = operation_history | |||||
for item in json_model["layer_list"]: | |||||
new_layer = layer_description_builder(item[1], id_to_node) | |||||
layer_id = int(item[0]) | |||||
layer_list.append(new_layer) | |||||
layer_to_id[new_layer] = layer_id | |||||
graph.node_list = node_list | |||||
graph.node_to_id = node_to_id | |||||
graph.layer_list = layer_list | |||||
graph.layer_to_id = layer_to_id | |||||
return graph |
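# Hedged usage sketch (added for illustration; the file path below is
# hypothetical): a graph survives a JSON round trip, since json_to_graph()
# replays the stored nodes, layers, adjacency lists and operation history.
# json_str = graph_to_json(graph, "/tmp/model_selected_space.json")
# restored = json_to_graph(json_str)
# assert restored.input_shape == graph.input_shape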
@@ -0,0 +1,178 @@ | |||||
from copy import deepcopy | |||||
from random import randrange, sample | |||||
from .graph import NetworkDescriptor | |||||
from .layers import ( | |||||
StubDense, | |||||
StubReLU, | |||||
get_batch_norm_class, | |||||
get_conv_class, | |||||
get_dropout_class, | |||||
get_pooling_class, | |||||
is_layer, | |||||
) | |||||
from utils import Constant | |||||
def to_wider_graph(graph): | |||||
'''Randomly pick a layer and widen it.
'''
weighted_layer_ids = graph.wide_layer_ids() | |||||
weighted_layer_ids = list( | |||||
filter( | |||||
lambda x: graph.layer_list[x].output.shape[-1], weighted_layer_ids) | |||||
) | |||||
wider_layers = sample(weighted_layer_ids, 1) | |||||
# Count the layers whose doubled width would exceed the maximum layer width.
layer_width_maxed = 0 | |||||
for layer_id in wider_layers: | |||||
layer = graph.layer_list[layer_id] | |||||
if is_layer(layer, "Conv"): | |||||
n_add = layer.filters | |||||
else: | |||||
n_add = layer.units | |||||
if n_add * 2 > Constant.MAX_LAYER_WIDTH:
layer_width_maxed += 1 | |||||
continue | |||||
graph.to_wider_model(layer_id, n_add) | |||||
if layer_width_maxed == len(wider_layers): | |||||
return None | |||||
return graph | |||||
def to_skip_connection_graph(graph): | |||||
'''Randomly add a skip-connection (add or concat) between two layers.
'''
# The last conv layer cannot be widened, since the wider operator cannot
# be applied across the two sides of the flatten layer.
weighted_layer_ids = graph.skip_connection_layer_ids() | |||||
valid_connection = [] | |||||
for skip_type in sorted( | |||||
[NetworkDescriptor.ADD_CONNECT, NetworkDescriptor.CONCAT_CONNECT]): | |||||
for index_a in range(len(weighted_layer_ids)): | |||||
for index_b in range(index_a + 1, len(weighted_layer_ids)):
valid_connection.append((index_a, index_b, skip_type)) | |||||
if len(valid_connection) < 1: | |||||
return graph | |||||
for index_a, index_b, skip_type in sample(valid_connection, 1): | |||||
a_id = weighted_layer_ids[index_a] | |||||
b_id = weighted_layer_ids[index_b] | |||||
if skip_type == NetworkDescriptor.ADD_CONNECT: | |||||
graph.to_add_skip_model(a_id, b_id) | |||||
else: | |||||
graph.to_concat_skip_model(a_id, b_id) | |||||
return graph | |||||
def create_new_layer(layer, n_dim): | |||||
'''Create a new layer to insert when deepening the graph.
'''
input_shape = layer.output.shape | |||||
# Candidate layer classes for the general case.
dense_deeper_classes = [StubDense, get_dropout_class(n_dim), StubReLU] | |||||
conv_deeper_classes = [ | |||||
get_conv_class(n_dim), | |||||
get_batch_norm_class(n_dim), | |||||
StubReLU] | |||||
# Three cases use special candidate layer classes.
if is_layer(layer, "ReLU"): | |||||
conv_deeper_classes = [ | |||||
get_conv_class(n_dim), | |||||
get_batch_norm_class(n_dim)] | |||||
dense_deeper_classes = [StubDense, get_dropout_class(n_dim)] | |||||
elif is_layer(layer, "Dropout"): | |||||
dense_deeper_classes = [StubDense, StubReLU] | |||||
elif is_layer(layer, "BatchNormalization"): | |||||
conv_deeper_classes = [get_conv_class(n_dim), StubReLU] | |||||
layer_class = None | |||||
if len(input_shape) == 1: | |||||
# It is in the dense layer part. | |||||
layer_class = sample(dense_deeper_classes, 1)[0] | |||||
else: | |||||
# It is in the conv layer part. | |||||
layer_class = sample(conv_deeper_classes, 1)[0] | |||||
if layer_class == StubDense: | |||||
new_layer = StubDense(input_shape[0], input_shape[0]) | |||||
elif layer_class == get_dropout_class(n_dim): | |||||
new_layer = layer_class(Constant.DENSE_DROPOUT_RATE) | |||||
elif layer_class == get_conv_class(n_dim): | |||||
new_layer = layer_class( | |||||
input_shape[-1], input_shape[-1], sample((1, 3, 5), 1)[0], stride=1 | |||||
) | |||||
elif layer_class == get_batch_norm_class(n_dim): | |||||
new_layer = layer_class(input_shape[-1]) | |||||
elif layer_class == get_pooling_class(n_dim): | |||||
new_layer = layer_class(sample((1, 3, 5), 1)[0]) | |||||
else: | |||||
new_layer = layer_class() | |||||
return new_layer | |||||
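# Added note (illustrative): for a layer in the conv part with output
# shape, say, (8, 8, 64), create_new_layer() samples one of the conv
# candidate classes; if the conv class is drawn, the new layer keeps 64
# input and 64 output channels with a kernel size drawn from (1, 3, 5)
# and stride 1, so inserting it preserves the tensor shape.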
def to_deeper_graph(graph): | |||||
'''Randomly insert a new layer to make the graph deeper.
'''
weighted_layer_ids = graph.deep_layer_ids() | |||||
if len(weighted_layer_ids) >= Constant.MAX_LAYERS: | |||||
return None | |||||
deeper_layer_ids = sample(weighted_layer_ids, 1) | |||||
for layer_id in deeper_layer_ids: | |||||
layer = graph.layer_list[layer_id] | |||||
new_layer = create_new_layer(layer, graph.n_dim) | |||||
graph.to_deeper_model(layer_id, new_layer) | |||||
return graph | |||||
def legal_graph(graph): | |||||
'''Judge whether a graph is legal (it has no duplicated skip-connections).
'''
descriptor = graph.extract_descriptor() | |||||
skips = descriptor.skip_connections | |||||
if len(skips) != len(set(skips)): | |||||
return False | |||||
return True | |||||
# Morph the network f with operations drawn from the operation set O.
def transform(graph): | |||||
'''Core transform function: generate a batch of mutated neighbour graphs.
'''
graphs = [] | |||||
for _ in range(Constant.N_NEIGHBOURS * 2): | |||||
random_num = randrange(3) | |||||
temp_graph = None | |||||
if random_num == 0: | |||||
temp_graph = to_deeper_graph(deepcopy(graph)) | |||||
elif random_num == 1: | |||||
temp_graph = to_wider_graph(deepcopy(graph)) | |||||
elif random_num == 2: | |||||
temp_graph = to_skip_connection_graph(deepcopy(graph)) | |||||
if temp_graph is not None and temp_graph.size() <= Constant.MAX_MODEL_SIZE: | |||||
graphs.append(temp_graph) | |||||
# Stop after at most N_NEIGHBOURS candidate graphs.
if len(graphs) >= Constant.N_NEIGHBOURS: | |||||
break | |||||
return graphs |
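# Hedged usage sketch (added; `graph` is any Graph instance): transform()
# returns up to Constant.N_NEIGHBOURS mutated copies, which a searcher can
# filter for legality before scoring:
# candidates = [g for g in transform(graph) if legal_graph(g)]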
@@ -0,0 +1,213 @@ | |||||
import numpy as np | |||||
from .layers import ( | |||||
StubDense, | |||||
get_batch_norm_class, | |||||
get_conv_class, | |||||
get_n_dim, | |||||
) | |||||
NOISE_RATIO = 1e-4 | |||||
def wider_pre_dense(layer, n_add, weighted=True): | |||||
'''Widen the preceding dense layer by n_add output units.
'''
if not weighted: | |||||
return StubDense(layer.input_units, layer.units + n_add) | |||||
n_units2 = layer.units | |||||
teacher_w, teacher_b = layer.get_weights() | |||||
rand = np.random.randint(n_units2, size=n_add) | |||||
student_w = teacher_w.copy() | |||||
student_b = teacher_b.copy() | |||||
# Duplicate randomly chosen teacher units (with noise) to fill the new width.
for i in range(n_add): | |||||
teacher_index = rand[i] | |||||
new_weight = teacher_w[teacher_index, :] | |||||
new_weight = new_weight[np.newaxis, :] | |||||
student_w = np.concatenate( | |||||
(student_w, add_noise(new_weight, student_w)), axis=0) | |||||
student_b = np.append( | |||||
student_b, add_noise( | |||||
teacher_b[teacher_index], student_b)) | |||||
new_pre_layer = StubDense(layer.input_units, n_units2 + n_add) | |||||
new_pre_layer.set_weights((student_w, student_b)) | |||||
return new_pre_layer | |||||
def wider_pre_conv(layer, n_add_filters, weighted=True): | |||||
'''Widen the preceding conv layer by n_add_filters output filters.
'''
n_dim = get_n_dim(layer) | |||||
if not weighted: | |||||
return get_conv_class(n_dim)( | |||||
layer.input_channel, | |||||
layer.filters + n_add_filters, | |||||
kernel_size=layer.kernel_size, | |||||
stride=layer.stride | |||||
) | |||||
n_pre_filters = layer.filters | |||||
rand = np.random.randint(n_pre_filters, size=n_add_filters) | |||||
teacher_w, teacher_b = layer.get_weights() | |||||
student_w = teacher_w.copy() | |||||
student_b = teacher_b.copy() | |||||
# Duplicate randomly chosen teacher filters to fill the new width.
for i in range(len(rand)): | |||||
teacher_index = rand[i] | |||||
new_weight = teacher_w[teacher_index, ...] | |||||
new_weight = new_weight[np.newaxis, ...] | |||||
student_w = np.concatenate((student_w, new_weight), axis=0) | |||||
student_b = np.append(student_b, teacher_b[teacher_index]) | |||||
new_pre_layer = get_conv_class(n_dim)( | |||||
layer.input_channel, | |||||
n_pre_filters + n_add_filters, | |||||
kernel_size=layer.kernel_size, | |||||
stride=layer.stride | |||||
) | |||||
new_pre_layer.set_weights( | |||||
(add_noise(student_w, teacher_w), add_noise(student_b, teacher_b)) | |||||
) | |||||
return new_pre_layer | |||||
def wider_next_conv(layer, start_dim, total_dim, n_add, weighted=True): | |||||
'''Widen the following conv layer to accept n_add extra input channels.
'''
n_dim = get_n_dim(layer) | |||||
if not weighted: | |||||
return get_conv_class(n_dim)(layer.input_channel + n_add, | |||||
layer.filters, | |||||
kernel_size=layer.kernel_size, | |||||
stride=layer.stride) | |||||
n_filters = layer.filters | |||||
teacher_w, teacher_b = layer.get_weights() | |||||
new_weight_shape = list(teacher_w.shape) | |||||
new_weight_shape[1] = n_add | |||||
new_weight = np.zeros(tuple(new_weight_shape)) | |||||
student_w = np.concatenate((teacher_w[:, :start_dim, ...].copy(), | |||||
add_noise(new_weight, teacher_w), | |||||
teacher_w[:, start_dim:total_dim, ...].copy()), axis=1) | |||||
new_layer = get_conv_class(n_dim)(layer.input_channel + n_add, | |||||
n_filters, | |||||
kernel_size=layer.kernel_size, | |||||
stride=layer.stride) | |||||
new_layer.set_weights((student_w, teacher_b)) | |||||
return new_layer | |||||
def wider_bn(layer, start_dim, total_dim, n_add, weighted=True): | |||||
'''Widen the batch-norm layer by n_add features.
'''
n_dim = get_n_dim(layer) | |||||
if not weighted: | |||||
return get_batch_norm_class(n_dim)(layer.num_features + n_add) | |||||
weights = layer.get_weights() | |||||
new_weights = [ | |||||
add_noise(np.ones(n_add, dtype=np.float32), np.array([0, 1])), | |||||
add_noise(np.zeros(n_add, dtype=np.float32), np.array([0, 1])), | |||||
add_noise(np.zeros(n_add, dtype=np.float32), np.array([0, 1])), | |||||
add_noise(np.ones(n_add, dtype=np.float32), np.array([0, 1])), | |||||
] | |||||
student_w = tuple() | |||||
for weight, new_weight in zip(weights, new_weights): | |||||
temp_w = weight.copy() | |||||
temp_w = np.concatenate( | |||||
(temp_w[:start_dim], new_weight, temp_w[start_dim:total_dim]) | |||||
) | |||||
student_w += (temp_w,) | |||||
new_layer = get_batch_norm_class(n_dim)(layer.num_features + n_add) | |||||
new_layer.set_weights(student_w) | |||||
return new_layer | |||||
def wider_next_dense(layer, start_dim, total_dim, n_add, weighted=True): | |||||
'''Widen the following dense layer to accept n_add extra input units.
'''
if not weighted: | |||||
return StubDense(layer.input_units + n_add, layer.units) | |||||
teacher_w, teacher_b = layer.get_weights() | |||||
student_w = teacher_w.copy() | |||||
n_units_each_channel = int(teacher_w.shape[1] / total_dim) | |||||
new_weight = np.zeros((teacher_w.shape[0], n_add * n_units_each_channel)) | |||||
student_w = np.concatenate( | |||||
( | |||||
student_w[:, : start_dim * n_units_each_channel], | |||||
add_noise(new_weight, student_w), | |||||
student_w[ | |||||
:, start_dim * n_units_each_channel: total_dim * n_units_each_channel | |||||
], | |||||
), | |||||
axis=1, | |||||
) | |||||
new_layer = StubDense(layer.input_units + n_add, layer.units) | |||||
new_layer.set_weights((student_w, teacher_b)) | |||||
return new_layer | |||||
def add_noise(weights, other_weights): | |||||
'''Add uniform noise to weights, scaled by the value range of other_weights.
'''
w_range = np.ptp(other_weights.flatten()) | |||||
noise_range = NOISE_RATIO * w_range | |||||
noise = np.random.uniform(-noise_range / 2.0, | |||||
noise_range / 2.0, weights.shape) | |||||
return np.add(noise, weights) | |||||
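# Added worked example: with other_weights = np.array([0, 1]) the
# peak-to-peak range is 1, so noise is drawn uniformly from
# [-NOISE_RATIO / 2, NOISE_RATIO / 2] = [-5e-5, 5e-5]; the morphed child
# network therefore starts functionally close to its teacher.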
def init_dense_weight(layer): | |||||
'''Initialize dense layer weights to a noisy identity mapping.
'''
units = layer.units | |||||
weight = np.eye(units) | |||||
bias = np.zeros(units) | |||||
layer.set_weights( | |||||
(add_noise(weight, np.array([0, 1])), | |||||
add_noise(bias, np.array([0, 1]))) | |||||
) | |||||
def init_conv_weight(layer): | |||||
'''Initialize conv layer weights to a noisy identity mapping.
'''
n_filters = layer.filters | |||||
filter_shape = (layer.kernel_size,) * get_n_dim(layer) | |||||
weight = np.zeros((n_filters, n_filters) + filter_shape) | |||||
center = tuple(map(lambda x: int((x - 1) / 2), filter_shape)) | |||||
for i in range(n_filters): | |||||
filter_weight = np.zeros((n_filters,) + filter_shape) | |||||
index = (i,) + center | |||||
filter_weight[index] = 1 | |||||
weight[i, ...] = filter_weight | |||||
bias = np.zeros(n_filters) | |||||
layer.set_weights( | |||||
(add_noise(weight, np.array([0, 1])), | |||||
add_noise(bias, np.array([0, 1]))) | |||||
) | |||||
def init_bn_weight(layer): | |||||
'''Initialize batch-norm layer weights to a noisy identity transform.
'''
n_filters = layer.num_features | |||||
new_weights = [ | |||||
add_noise(np.ones(n_filters, dtype=np.float32), np.array([0, 1])), | |||||
add_noise(np.zeros(n_filters, dtype=np.float32), np.array([0, 1])), | |||||
add_noise(np.zeros(n_filters, dtype=np.float32), np.array([0, 1])), | |||||
add_noise(np.ones(n_filters, dtype=np.float32), np.array([0, 1])), | |||||
] | |||||
layer.set_weights(new_weights) |
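# Hedged usage sketch (added for illustration; layer names and n_add are
# assumptions): widening a dense layer and its successor together keeps
# the network's function approximately unchanged, assuming the new units
# are appended at the end as above:
# pre = wider_pre_dense(old_pre, n_add=16)
# nxt = wider_next_dense(old_next, start_dim=old_pre.units,
#                        total_dim=old_pre.units, n_add=16)
# The pre-layer duplicates random teacher units (plus noise), while the
# next layer gets zero weights for the new inputs, so outputs match the
# teacher up to the injected noise.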
@@ -0,0 +1,765 @@ | |||||
from abc import abstractmethod | |||||
from collections.abc import Iterable | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import functional | |||||
from utils import Constant | |||||
class AvgPool(nn.Module): | |||||
""" | |||||
AvgPool Module. | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
@abstractmethod | |||||
def forward(self, input_tensor): | |||||
pass | |||||
class GlobalAvgPool1d(AvgPool): | |||||
""" | |||||
GlobalAvgPool1d Module. | |||||
""" | |||||
def forward(self, input_tensor): | |||||
return functional.avg_pool1d(input_tensor, input_tensor.size()[2:]).view( | |||||
input_tensor.size()[:2] | |||||
) | |||||
class GlobalAvgPool2d(AvgPool): | |||||
""" | |||||
GlobalAvgPool2d Module. | |||||
""" | |||||
def forward(self, input_tensor): | |||||
return functional.avg_pool2d(input_tensor, input_tensor.size()[2:]).view( | |||||
input_tensor.size()[:2] | |||||
) | |||||
class GlobalAvgPool3d(AvgPool): | |||||
""" | |||||
GlobalAvgPool3d Module. | |||||
""" | |||||
def forward(self, input_tensor): | |||||
return functional.avg_pool3d(input_tensor, input_tensor.size()[2:]).view( | |||||
input_tensor.size()[:2] | |||||
) | |||||
class StubLayer: | |||||
""" | |||||
Base class for all stub layers. A stub layer records shapes and weights without performing computation.
""" | |||||
def __init__(self, input_node=None, output_node=None): | |||||
self.input = input_node | |||||
self.output = output_node | |||||
self.weights = None | |||||
def build(self, shape): | |||||
""" | |||||
build shape. | |||||
""" | |||||
def set_weights(self, weights): | |||||
""" | |||||
set weights. | |||||
""" | |||||
self.weights = weights | |||||
def import_weights(self, torch_layer): | |||||
""" | |||||
import weights. | |||||
""" | |||||
def export_weights(self, torch_layer): | |||||
""" | |||||
export weights. | |||||
""" | |||||
def get_weights(self): | |||||
""" | |||||
get weights. | |||||
""" | |||||
return self.weights | |||||
def size(self): | |||||
""" | |||||
Return the number of parameters in the layer.
""" | |||||
return 0 | |||||
@property | |||||
def output_shape(self): | |||||
""" | |||||
output shape. | |||||
""" | |||||
return self.input.shape | |||||
def to_real_layer(self): | |||||
""" | |||||
Convert this stub into the corresponding real torch layer.
""" | |||||
def __str__(self): | |||||
""" | |||||
Return the class name without the "Stub" prefix.
""" | |||||
return type(self).__name__[4:] | |||||
class StubWeightBiasLayer(StubLayer): | |||||
""" | |||||
Base stub for layers that carry a weight tensor and a bias tensor.
""" | |||||
def import_weights(self, torch_layer): | |||||
self.set_weights( | |||||
(torch_layer.weight.data.cpu().numpy(), | |||||
torch_layer.bias.data.cpu().numpy()) | |||||
) | |||||
def export_weights(self, torch_layer): | |||||
torch_layer.weight.data = torch.Tensor(self.weights[0]) | |||||
torch_layer.bias.data = torch.Tensor(self.weights[1]) | |||||
class StubBatchNormalization(StubWeightBiasLayer): | |||||
""" | |||||
StubBatchNormalization Module. Batch Norm. | |||||
""" | |||||
def __init__(self, num_features, input_node=None, output_node=None): | |||||
super().__init__(input_node, output_node) | |||||
self.num_features = num_features | |||||
def import_weights(self, torch_layer): | |||||
self.set_weights( | |||||
( | |||||
torch_layer.weight.data.cpu().numpy(), | |||||
torch_layer.bias.data.cpu().numpy(), | |||||
torch_layer.running_mean.cpu().numpy(), | |||||
torch_layer.running_var.cpu().numpy(), | |||||
) | |||||
) | |||||
def export_weights(self, torch_layer): | |||||
torch_layer.weight.data = torch.Tensor(self.weights[0]) | |||||
torch_layer.bias.data = torch.Tensor(self.weights[1]) | |||||
torch_layer.running_mean = torch.Tensor(self.weights[2]) | |||||
torch_layer.running_var = torch.Tensor(self.weights[3]) | |||||
def size(self): | |||||
return self.num_features * 4 | |||||
@abstractmethod | |||||
def to_real_layer(self): | |||||
pass | |||||
class StubBatchNormalization1d(StubBatchNormalization): | |||||
""" | |||||
StubBatchNormalization1d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.BatchNorm1d(self.num_features) | |||||
class StubBatchNormalization2d(StubBatchNormalization): | |||||
""" | |||||
StubBatchNormalization2d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.BatchNorm2d(self.num_features) | |||||
class StubBatchNormalization3d(StubBatchNormalization): | |||||
""" | |||||
StubBatchNormalization3d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.BatchNorm3d(self.num_features) | |||||
class StubDense(StubWeightBiasLayer): | |||||
""" | |||||
StubDense Module. Linear. | |||||
""" | |||||
def __init__(self, input_units, units, input_node=None, output_node=None): | |||||
super().__init__(input_node, output_node) | |||||
self.input_units = input_units | |||||
self.units = units | |||||
@property | |||||
def output_shape(self): | |||||
return (self.units,) | |||||
def size(self): | |||||
return self.input_units * self.units + self.units | |||||
def to_real_layer(self): | |||||
return torch.nn.Linear(self.input_units, self.units) | |||||
class StubConv(StubWeightBiasLayer): | |||||
""" | |||||
StubConv Module. Conv. | |||||
""" | |||||
def __init__(self, input_channel, filters, kernel_size, | |||||
stride=1, input_node=None, output_node=None): | |||||
super().__init__(input_node, output_node) | |||||
self.input_channel = input_channel | |||||
self.filters = filters | |||||
self.kernel_size = kernel_size | |||||
self.stride = stride | |||||
self.padding = int(self.kernel_size / 2) | |||||
@property | |||||
def output_shape(self): | |||||
ret = list(self.input.shape[:-1]) | |||||
for index, dim in enumerate(ret): | |||||
ret[index] = ( | |||||
int((dim + 2 * self.padding - self.kernel_size) / self.stride) + 1 | |||||
) | |||||
ret = ret + [self.filters] | |||||
return tuple(ret) | |||||
def size(self): | |||||
return (self.input_channel * self.kernel_size * | |||||
self.kernel_size + 1) * self.filters | |||||
@abstractmethod | |||||
def to_real_layer(self): | |||||
pass | |||||
def __str__(self): | |||||
return ( | |||||
super().__str__() | |||||
+ "(" | |||||
+ ", ".join( | |||||
str(item) | |||||
for item in [ | |||||
self.input_channel, | |||||
self.filters, | |||||
self.kernel_size, | |||||
self.stride, | |||||
] | |||||
) | |||||
+ ")" | |||||
) | |||||
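# Added worked example: with padding = kernel_size // 2, a
# StubConv2d(3, 64, kernel_size=3, stride=1) maps an input of shape
# (32, 32, 3) to (32, 32, 64), since int((32 + 2 * 1 - 3) / 1) + 1 = 32;
# with stride=2 the spatial dims become int((32 + 2 - 3) / 2) + 1 = 16.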
class StubConv1d(StubConv): | |||||
""" | |||||
StubConv1d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.Conv1d( | |||||
self.input_channel, | |||||
self.filters, | |||||
self.kernel_size, | |||||
stride=self.stride, | |||||
padding=self.padding, | |||||
) | |||||
class StubConv2d(StubConv): | |||||
""" | |||||
StubConv2d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.Conv2d( | |||||
self.input_channel, | |||||
self.filters, | |||||
self.kernel_size, | |||||
stride=self.stride, | |||||
padding=self.padding, | |||||
) | |||||
class StubConv3d(StubConv): | |||||
""" | |||||
StubConv3d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.Conv3d( | |||||
self.input_channel, | |||||
self.filters, | |||||
self.kernel_size, | |||||
stride=self.stride, | |||||
padding=self.padding, | |||||
) | |||||
class StubAggregateLayer(StubLayer): | |||||
""" | |||||
StubAggregateLayer Module. | |||||
""" | |||||
def __init__(self, input_nodes=None, output_node=None): | |||||
if input_nodes is None: | |||||
input_nodes = [] | |||||
super().__init__(input_nodes, output_node) | |||||
class StubConcatenate(StubAggregateLayer): | |||||
"""StubConcatenate Module. | |||||
""" | |||||
@property | |||||
def output_shape(self): | |||||
ret = 0 | |||||
for current_input in self.input: | |||||
ret += current_input.shape[-1] | |||||
ret = self.input[0].shape[:-1] + (ret,) | |||||
return ret | |||||
def to_real_layer(self): | |||||
return TorchConcatenate() | |||||
class StubAdd(StubAggregateLayer): | |||||
""" | |||||
StubAdd Module. | |||||
""" | |||||
@property | |||||
def output_shape(self): | |||||
return self.input[0].shape | |||||
def to_real_layer(self): | |||||
return TorchAdd() | |||||
class StubFlatten(StubLayer): | |||||
""" | |||||
StubFlatten Module. | |||||
""" | |||||
@property | |||||
def output_shape(self): | |||||
ret = 1 | |||||
for dim in self.input.shape: | |||||
ret *= dim | |||||
return (ret,) | |||||
def to_real_layer(self): | |||||
return TorchFlatten() | |||||
class StubReLU(StubLayer): | |||||
""" | |||||
StubReLU Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.ReLU() | |||||
class StubSoftmax(StubLayer): | |||||
""" | |||||
StubSoftmax Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.LogSoftmax(dim=1) | |||||
class StubDropout(StubLayer): | |||||
""" | |||||
StubDropout Module. | |||||
""" | |||||
def __init__(self, rate, input_node=None, output_node=None): | |||||
super().__init__(input_node, output_node) | |||||
self.rate = rate | |||||
@abstractmethod | |||||
def to_real_layer(self): | |||||
pass | |||||
class StubDropout1d(StubDropout): | |||||
""" | |||||
StubDropout1d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.Dropout(self.rate) | |||||
class StubDropout2d(StubDropout): | |||||
""" | |||||
StubDropout2d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.Dropout2d(self.rate) | |||||
class StubDropout3d(StubDropout): | |||||
""" | |||||
StubDropout3d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.Dropout3d(self.rate) | |||||
class StubInput(StubLayer): | |||||
""" | |||||
StubInput Module. | |||||
""" | |||||
def __init__(self, input_node=None, output_node=None): | |||||
super().__init__(input_node, output_node) | |||||
class StubPooling(StubLayer): | |||||
""" | |||||
StubPooling Module. | |||||
""" | |||||
def __init__(self, | |||||
kernel_size=None, | |||||
stride=None, | |||||
padding=0, | |||||
input_node=None, | |||||
output_node=None): | |||||
super().__init__(input_node, output_node) | |||||
self.kernel_size = ( | |||||
kernel_size if kernel_size is not None else Constant.POOLING_KERNEL_SIZE | |||||
) | |||||
self.stride = stride if stride is not None else self.kernel_size | |||||
self.padding = padding | |||||
@property | |||||
def output_shape(self): | |||||
ret = tuple() | |||||
for dim in self.input.shape[:-1]: | |||||
ret = ret + (max(int((dim + 2 * self.padding) / self.kernel_size), 1),) | |||||
ret = ret + (self.input.shape[-1],) | |||||
return ret | |||||
@abstractmethod | |||||
def to_real_layer(self): | |||||
pass | |||||
class StubPooling1d(StubPooling): | |||||
""" | |||||
StubPooling1d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.MaxPool1d(self.kernel_size, stride=self.stride) | |||||
class StubPooling2d(StubPooling): | |||||
""" | |||||
StubPooling2d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.MaxPool2d(self.kernel_size, stride=self.stride) | |||||
class StubPooling3d(StubPooling): | |||||
""" | |||||
StubPooling3d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return torch.nn.MaxPool3d(self.kernel_size, stride=self.stride) | |||||
class StubGlobalPooling(StubLayer): | |||||
""" | |||||
StubGlobalPooling Module. | |||||
""" | |||||
def __init__(self, input_node=None, output_node=None): | |||||
super().__init__(input_node, output_node) | |||||
@property | |||||
def output_shape(self): | |||||
return (self.input.shape[-1],) | |||||
@abstractmethod | |||||
def to_real_layer(self): | |||||
pass | |||||
class StubGlobalPooling1d(StubGlobalPooling): | |||||
""" | |||||
StubGlobalPooling1d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return GlobalAvgPool1d() | |||||
class StubGlobalPooling2d(StubGlobalPooling): | |||||
""" | |||||
StubGlobalPooling2d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return GlobalAvgPool2d() | |||||
class StubGlobalPooling3d(StubGlobalPooling): | |||||
""" | |||||
StubGlobalPooling3d Module. | |||||
""" | |||||
def to_real_layer(self): | |||||
return GlobalAvgPool3d() | |||||
class TorchConcatenate(nn.Module): | |||||
""" | |||||
TorchConcatenate Module. | |||||
""" | |||||
def forward(self, input_list): | |||||
return torch.cat(input_list, dim=1) | |||||
class TorchAdd(nn.Module): | |||||
""" | |||||
TorchAdd Module. | |||||
""" | |||||
def forward(self, input_list): | |||||
return input_list[0] + input_list[1] | |||||
class TorchFlatten(nn.Module): | |||||
""" | |||||
TorchFlatten Module. | |||||
""" | |||||
def forward(self, input_tensor): | |||||
return input_tensor.view(input_tensor.size(0), -1) | |||||
def is_layer(layer, layer_type): | |||||
""" | |||||
Check whether the layer is of the given layer_type.
Returns
-------
bool
True if the layer matches layer_type, otherwise False.
""" | |||||
if layer_type == "Input": | |||||
return isinstance(layer, StubInput) | |||||
elif layer_type == "Conv": | |||||
return isinstance(layer, StubConv) | |||||
elif layer_type == "Dense": | |||||
return isinstance(layer, (StubDense,)) | |||||
elif layer_type == "BatchNormalization": | |||||
return isinstance(layer, (StubBatchNormalization,)) | |||||
elif layer_type == "Concatenate": | |||||
return isinstance(layer, (StubConcatenate,)) | |||||
elif layer_type == "Add": | |||||
return isinstance(layer, (StubAdd,)) | |||||
elif layer_type == "Pooling": | |||||
return isinstance(layer, StubPooling) | |||||
elif layer_type == "Dropout": | |||||
return isinstance(layer, (StubDropout,)) | |||||
elif layer_type == "Softmax": | |||||
return isinstance(layer, (StubSoftmax,)) | |||||
elif layer_type == "ReLU": | |||||
return isinstance(layer, (StubReLU,)) | |||||
elif layer_type == "Flatten": | |||||
return isinstance(layer, (StubFlatten,)) | |||||
elif layer_type == "GlobalAveragePooling": | |||||
return isinstance(layer, StubGlobalPooling) | |||||
return False  # Unknown layer_type strings match no stub class.
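# Added illustration: is_layer(StubConv2d(3, 16, 3), "Conv") and
# is_layer(StubBatchNormalization2d(16), "BatchNormalization") are True,
# while is_layer(StubReLU(), "Conv") is False.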
def layer_description_extractor(layer, node_to_id): | |||||
""" | |||||
Get layer description. | |||||
""" | |||||
layer_input = layer.input | |||||
layer_output = layer.output | |||||
if layer_input is not None: | |||||
if isinstance(layer_input, Iterable): | |||||
layer_input = list(map(lambda x: node_to_id[x], layer_input)) | |||||
else: | |||||
layer_input = node_to_id[layer_input] | |||||
if layer_output is not None: | |||||
layer_output = node_to_id[layer_output] | |||||
if isinstance(layer, StubConv): | |||||
return ( | |||||
type(layer).__name__, | |||||
layer_input, | |||||
layer_output, | |||||
layer.input_channel, | |||||
layer.filters, | |||||
layer.kernel_size, | |||||
layer.stride, | |||||
layer.padding, | |||||
) | |||||
elif isinstance(layer, (StubDense,)): | |||||
return [ | |||||
type(layer).__name__, | |||||
layer_input, | |||||
layer_output, | |||||
layer.input_units, | |||||
layer.units, | |||||
] | |||||
elif isinstance(layer, (StubBatchNormalization,)): | |||||
return (type(layer).__name__, layer_input, | |||||
layer_output, layer.num_features) | |||||
elif isinstance(layer, (StubDropout,)): | |||||
return (type(layer).__name__, layer_input, layer_output, layer.rate) | |||||
elif isinstance(layer, StubPooling): | |||||
return ( | |||||
type(layer).__name__, | |||||
layer_input, | |||||
layer_output, | |||||
layer.kernel_size, | |||||
layer.stride, | |||||
layer.padding, | |||||
) | |||||
else: | |||||
return (type(layer).__name__, layer_input, layer_output) | |||||
def layer_description_builder(layer_information, id_to_node): | |||||
"""build layer from description. | |||||
""" | |||||
layer_type = layer_information[0] | |||||
layer_input_ids = layer_information[1] | |||||
if isinstance(layer_input_ids, Iterable): | |||||
layer_input = list(map(lambda x: id_to_node[x], layer_input_ids)) | |||||
else: | |||||
layer_input = id_to_node[layer_input_ids] | |||||
layer_output = id_to_node[layer_information[2]] | |||||
if layer_type.startswith("StubConv"): | |||||
input_channel = layer_information[3] | |||||
filters = layer_information[4] | |||||
kernel_size = layer_information[5] | |||||
stride = layer_information[6] | |||||
return globals()[layer_type]( | |||||
input_channel, filters, kernel_size, stride, layer_input, layer_output | |||||
) | |||||
elif layer_type.startswith("StubDense"): | |||||
input_units = layer_information[3] | |||||
units = layer_information[4] | |||||
return globals()[layer_type](input_units, units, layer_input, layer_output) | |||||
elif layer_type.startswith("StubBatchNormalization"): | |||||
num_features = layer_information[3] | |||||
return globals()[layer_type](num_features, layer_input, layer_output) | |||||
elif layer_type.startswith("StubDropout"): | |||||
rate = layer_information[3] | |||||
return globals()[layer_type](rate, layer_input, layer_output) | |||||
elif layer_type.startswith("StubPooling"): | |||||
kernel_size = layer_information[3] | |||||
stride = layer_information[4] | |||||
padding = layer_information[5] | |||||
return globals()[layer_type](kernel_size, stride, padding, layer_input, layer_output) | |||||
else: | |||||
return globals()[layer_type](layer_input, layer_output) | |||||
def layer_width(layer): | |||||
""" | |||||
Get layer width. | |||||
""" | |||||
if is_layer(layer, "Dense"): | |||||
return layer.units | |||||
if is_layer(layer, "Conv"): | |||||
return layer.filters | |||||
raise TypeError("The layer should be either Dense or Conv layer.") | |||||
def set_torch_weight_to_stub(torch_layer, stub_layer): | |||||
stub_layer.import_weights(torch_layer) | |||||
def set_stub_weight_to_torch(stub_layer, torch_layer): | |||||
stub_layer.export_weights(torch_layer) | |||||
def get_conv_class(n_dim): | |||||
conv_class_list = [StubConv1d, StubConv2d, StubConv3d] | |||||
return conv_class_list[n_dim - 1] | |||||
def get_dropout_class(n_dim): | |||||
dropout_class_list = [StubDropout1d, StubDropout2d, StubDropout3d] | |||||
return dropout_class_list[n_dim - 1] | |||||
def get_global_avg_pooling_class(n_dim): | |||||
global_avg_pooling_class_list = [ | |||||
StubGlobalPooling1d, | |||||
StubGlobalPooling2d, | |||||
StubGlobalPooling3d, | |||||
] | |||||
return global_avg_pooling_class_list[n_dim - 1] | |||||
def get_pooling_class(n_dim): | |||||
pooling_class_list = [StubPooling1d, StubPooling2d, StubPooling3d] | |||||
return pooling_class_list[n_dim - 1] | |||||
def get_batch_norm_class(n_dim): | |||||
batch_norm_class_list = [ | |||||
StubBatchNormalization1d, | |||||
StubBatchNormalization2d, | |||||
StubBatchNormalization3d, | |||||
] | |||||
return batch_norm_class_list[n_dim - 1] | |||||
def get_n_dim(layer): | |||||
if isinstance(layer, ( | |||||
StubConv1d, | |||||
StubDropout1d, | |||||
StubGlobalPooling1d, | |||||
StubPooling1d, | |||||
StubBatchNormalization1d, | |||||
)): | |||||
return 1 | |||||
if isinstance(layer, ( | |||||
StubConv2d, | |||||
StubDropout2d, | |||||
StubGlobalPooling2d, | |||||
StubPooling2d, | |||||
StubBatchNormalization2d, | |||||
)): | |||||
return 2 | |||||
if isinstance(layer, ( | |||||
StubConv3d, | |||||
StubDropout3d, | |||||
StubGlobalPooling3d, | |||||
StubPooling3d, | |||||
StubBatchNormalization3d, | |||||
)): | |||||
return 3 | |||||
return -1 |
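# Hedged usage sketch (added; values are illustrative): a stub layer is a
# lightweight description, and to_real_layer() materializes the matching
# torch module.
# stub = StubConv2d(3, 64, kernel_size=3, stride=1)
# real = stub.to_real_layer()  # torch.nn.Conv2d(3, 64, 3, stride=1, padding=1)
# assert get_n_dim(stub) == 2 and is_layer(stub, "Conv")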
@@ -0,0 +1,293 @@ | |||||
import logging | |||||
import os | |||||
import shutil | |||||
from utils import Constant, OptimizeMode | |||||
from .bayesian import BayesianOptimizer | |||||
from .nn import CnnGenerator, ResNetGenerator, MlpGenerator | |||||
from .graph import graph_to_json, json_to_graph | |||||
logger = logging.getLogger(__name__) | |||||
class NetworkMorphismSearcher: | |||||
""" | |||||
NetworkMorphismSearcher is a tuner that uses network morphism techniques.
Attributes | |||||
---------- | |||||
n_classes : int | |||||
The class number or output node number (default: ``10``) | |||||
input_shape : tuple | |||||
A tuple including: (input_width, input_width, input_channel) | |||||
t_min : float | |||||
The minimum temperature for simulated annealing. (default: ``Constant.T_MIN``) | |||||
beta : float | |||||
The beta in acquisition function. (default: ``Constant.BETA``) | |||||
algorithm_name : str | |||||
algorithm name used in the network morphism (default: ``"Bayesian"``) | |||||
optimize_mode : str | |||||
optimize mode "minimize" or "maximize" (default: ``"minimize"``) | |||||
verbose : bool | |||||
verbose to print the log (default: ``True``) | |||||
bo : BayesianOptimizer | |||||
The Bayesian optimizer used in the network morphism searcher.
max_model_size : int | |||||
max model size to the graph (default: ``Constant.MAX_MODEL_SIZE``) | |||||
default_model_len : int | |||||
default model length (default: ``Constant.MODEL_LEN``) | |||||
default_model_width : int | |||||
default model width (default: ``Constant.MODEL_WIDTH``) | |||||
search_space : dict | |||||
""" | |||||
def __init__( | |||||
self, | |||||
path, | |||||
best_selected_space_path, | |||||
task="cv", | |||||
input_width=32, | |||||
input_channel=3, | |||||
n_output_node=10, | |||||
algorithm_name="Bayesian", | |||||
optimize_mode="maximize", | |||||
verbose=True, | |||||
beta=Constant.BETA, | |||||
t_min=Constant.T_MIN, | |||||
max_model_size=Constant.MAX_MODEL_SIZE, | |||||
default_model_len=Constant.MODEL_LEN, | |||||
default_model_width=Constant.MODEL_WIDTH, | |||||
): | |||||
""" | |||||
Initializer of the NetworkMorphismSearcher.
""" | |||||
self.path = path | |||||
self.best_selected_space_path = best_selected_space_path | |||||
if task == "cv": | |||||
self.generators = [CnnGenerator] | |||||
elif task == "common": | |||||
self.generators = [MlpGenerator] | |||||
else: | |||||
raise NotImplementedError(
'{} task not supported in List ["cv","common"]'.format(task))
self.n_classes = n_output_node | |||||
self.input_shape = (input_width, input_width, input_channel) | |||||
self.t_min = t_min | |||||
self.beta = beta | |||||
self.algorithm_name = algorithm_name | |||||
self.optimize_mode = OptimizeMode(optimize_mode) | |||||
self.json = None | |||||
self.total_data = {} | |||||
self.verbose = verbose | |||||
self.bo = BayesianOptimizer( | |||||
self, self.t_min, self.optimize_mode, self.beta) | |||||
self.training_queue = [] | |||||
self.descriptors = [] | |||||
self.history = [] | |||||
self.max_model_size = max_model_size | |||||
self.default_model_len = default_model_len | |||||
self.default_model_width = default_model_width | |||||
def search(self, parameter_id, args): | |||||
""" | |||||
Return a trial neural architecture as a serializable object.
Parameters
----------
parameter_id : int
args : object
Trial arguments; args.trial_id supplies the new model id.
"""
if not self.history: | |||||
self.init_search(args) | |||||
new_father_id = None | |||||
generated_graph = None | |||||
# If the training queue already has entries, they come from init_search;
# otherwise, once there is a history, generate a new architecture.
if not self.training_queue: | |||||
new_father_id, generated_graph = self.generate() | |||||
new_model_id = args.trial_id | |||||
self.training_queue.append( | |||||
(generated_graph, new_father_id, new_model_id)) | |||||
self.descriptors.append(generated_graph.extract_descriptor()) | |||||
graph, father_id, model_id = self.training_queue.pop(0) | |||||
# from graph to json | |||||
json_out = graph_to_json(graph, os.path.join(self.path, str(model_id), 'model_selected_space.json'))
self.total_data[parameter_id] = (json_out, father_id, model_id) | |||||
return json_out | |||||
def update_searcher(self, parameter_id, value, **kwargs): | |||||
""" | |||||
Record an observation of the objective function. | |||||
Parameters | |||||
---------- | |||||
parameter_id : int | |||||
the id of a group of parameters generated by the nni manager.
value : dict/float | |||||
if value is dict, it should have "default" key. | |||||
""" | |||||
if parameter_id not in self.total_data: | |||||
raise RuntimeError("Received parameter_id not in total_data.") | |||||
(_, father_id, model_id) = self.total_data[parameter_id] | |||||
graph = self.bo.searcher.load_model_by_id(model_id) | |||||
# to use the value and graph | |||||
self.add_model(value, model_id) | |||||
self.update(father_id, graph, value, model_id) | |||||
def init_search(self, args):
""" | |||||
Call the generators to generate the initial architectures for the search. | |||||
""" | |||||
if self.verbose: | |||||
logger.info("Initializing search.") | |||||
for generator in self.generators: | |||||
graph = generator(self.n_classes, self.input_shape).generate( | |||||
self.default_model_len, self.default_model_width | |||||
) | |||||
model_id = args.trial_id | |||||
self.training_queue.append((graph, -1, model_id)) | |||||
self.descriptors.append(graph.extract_descriptor()) | |||||
if self.verbose: | |||||
logger.info("Initialization finished.") | |||||
def generate(self): | |||||
""" | |||||
Generate the next neural architecture. | |||||
Returns | |||||
------- | |||||
other_info : any object | |||||
Anything to be saved in the training queue together with the architecture. | |||||
generated_graph : Graph | |||||
An instance of Graph. | |||||
""" | |||||
generated_graph, new_father_id = self.bo.generate(self.descriptors) | |||||
if new_father_id is None: | |||||
new_father_id = 0 | |||||
generated_graph = self.generators[0]( | |||||
self.n_classes, self.input_shape | |||||
).generate(self.default_model_len, self.default_model_width) | |||||
return new_father_id, generated_graph | |||||
def update(self, other_info, graph, metric_value, model_id): | |||||
""" | |||||
Update the controller with evaluation result of a neural architecture. | |||||
Parameters | |||||
---------- | |||||
other_info: any object | |||||
In our case it is the father ID in the search tree. | |||||
graph: Graph | |||||
An instance of Graph. The trained neural architecture. | |||||
metric_value: float | |||||
The final evaluated metric value. | |||||
model_id: int | |||||
""" | |||||
father_id = other_info | |||||
self.bo.fit([graph.extract_descriptor()], [metric_value]) | |||||
self.bo.add_child(father_id, model_id) | |||||
def add_model(self, metric_value, model_id): | |||||
""" | |||||
Add a model's evaluation result to the history and refresh the best selected space.
Parameters | |||||
---------- | |||||
metric_value : float | |||||
model_id : int | |||||
Returns | |||||
------- | |||||
model : dict | |||||
""" | |||||
if self.verbose: | |||||
logger.info("Saving model.") | |||||
# Record the evaluation result in the history.
ret = {"model_id": model_id, "metric_value": metric_value} | |||||
self.history.append(ret) | |||||
# update best selected space | |||||
if model_id == self.get_best_model_id(): | |||||
best_model_path = os.path.join(self.path, str(model_id), 'model_selected_space.json')
shutil.copy(best_model_path, self.best_selected_space_path) | |||||
return ret | |||||
def get_best_model_id(self): | |||||
""" | |||||
Get the best model_id from history using the metric value | |||||
""" | |||||
if self.optimize_mode is OptimizeMode.Maximize: | |||||
return max(self.history, key=lambda x: x["metric_value"])[ | |||||
"model_id"] | |||||
return min(self.history, key=lambda x: x["metric_value"])["model_id"] | |||||
def load_model_by_id(self, model_id): | |||||
""" | |||||
Get the model by model_id | |||||
Parameters | |||||
---------- | |||||
model_id : int | |||||
model index | |||||
Returns | |||||
------- | |||||
load_model : Graph | |||||
the model graph representation | |||||
""" | |||||
with open(os.path.join(self.path, str(model_id), "model_selected_space.json")) as fin: | |||||
json_str = fin.read().replace("\n", "") | |||||
load_model = json_to_graph(json_str) | |||||
return load_model | |||||
def load_best_model(self): | |||||
""" | |||||
Get the best model by model id | |||||
Returns | |||||
------- | |||||
load_model : Graph | |||||
the model graph representation | |||||
""" | |||||
return self.load_model_by_id(self.get_best_model_id()) | |||||
def get_metric_value_by_id(self, model_id): | |||||
""" | |||||
Get the model metric value by its model_id
Parameters | |||||
---------- | |||||
model_id : int | |||||
model index | |||||
Returns | |||||
------- | |||||
float | |||||
the model metric | |||||
""" | |||||
for item in self.history: | |||||
if item["model_id"] == model_id: | |||||
return item["metric_value"] | |||||
return None |
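# Hedged usage sketch (added; `args` and `accuracy` are assumptions): a
# driver typically alternates search() and update_searcher():
# searcher = NetworkMorphismSearcher(path, best_selected_space_path)
# json_model = searcher.search(parameter_id, args)   # propose architecture
# ... train the model decoded via json_to_graph(json_model) ...
# searcher.update_searcher(parameter_id, accuracy)   # feed back the metric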
@@ -0,0 +1,227 @@ | |||||
from abc import abstractmethod | |||||
from .graph import Graph | |||||
from .layers import (StubAdd, StubDense, StubDropout1d, | |||||
StubReLU, get_batch_norm_class, | |||||
get_conv_class, | |||||
get_dropout_class, | |||||
get_global_avg_pooling_class, | |||||
get_pooling_class) | |||||
from utils import Constant | |||||
class NetworkGenerator: | |||||
"""The base class for generating a network. | |||||
It can be used to generate a CNN or Multi-Layer Perceptron. | |||||
Attributes: | |||||
n_output_node: Number of output nodes in the network. | |||||
input_shape: A tuple to represent the input shape. | |||||
""" | |||||
def __init__(self, n_output_node, input_shape): | |||||
self.n_output_node = n_output_node | |||||
self.input_shape = input_shape | |||||
@abstractmethod | |||||
def generate(self, model_len, model_width): | |||||
pass | |||||
class CnnGenerator(NetworkGenerator): | |||||
"""A class to generate CNN. | |||||
Attributes: | |||||
n_dim: `len(self.input_shape) - 1` | |||||
conv: A class that represents `(n_dim-1)` dimensional convolution. | |||||
dropout: A class that represents `(n_dim-1)` dimensional dropout. | |||||
global_avg_pooling: A class that represents `(n_dim-1)` dimensional Global Average Pooling. | |||||
pooling: A class that represents `(n_dim-1)` dimensional pooling. | |||||
batch_norm: A class that represents `(n_dim-1)` dimensional batch normalization. | |||||
""" | |||||
def __init__(self, n_output_node, input_shape): | |||||
super(CnnGenerator, self).__init__(n_output_node, input_shape) | |||||
self.n_dim = len(self.input_shape) - 1 | |||||
if len(self.input_shape) > 4: | |||||
raise ValueError("The input dimension is too high.") | |||||
if len(self.input_shape) < 2: | |||||
raise ValueError("The input dimension is too low.") | |||||
self.conv = get_conv_class(self.n_dim) | |||||
self.dropout = get_dropout_class(self.n_dim) | |||||
self.global_avg_pooling = get_global_avg_pooling_class(self.n_dim) | |||||
self.pooling = get_pooling_class(self.n_dim) | |||||
self.batch_norm = get_batch_norm_class(self.n_dim) | |||||
def generate(self, model_len=None, model_width=None): | |||||
"""Generates a CNN. | |||||
Args: | |||||
model_len: An integer. Number of convolutional layers. | |||||
model_width: An integer. Number of filters for the convolutional layers. | |||||
Returns: | |||||
An instance of the class Graph. Represents the neural architecture graph of the generated model. | |||||
""" | |||||
if model_len is None: | |||||
model_len = Constant.MODEL_LEN | |||||
if model_width is None: | |||||
model_width = Constant.MODEL_WIDTH | |||||
pooling_len = int(model_len / 4) | |||||
graph = Graph(self.input_shape, False) | |||||
temp_input_channel = self.input_shape[-1] | |||||
output_node_id = 0 | |||||
stride = 1 | |||||
for i in range(model_len): | |||||
output_node_id = graph.add_layer(StubReLU(), output_node_id) | |||||
output_node_id = graph.add_layer( | |||||
self.batch_norm( | |||||
graph.node_list[output_node_id].shape[-1]), output_node_id | |||||
) | |||||
output_node_id = graph.add_layer( | |||||
self.conv( | |||||
temp_input_channel, | |||||
model_width, | |||||
kernel_size=3, | |||||
stride=stride), | |||||
output_node_id, | |||||
) | |||||
temp_input_channel = model_width | |||||
if pooling_len == 0 or ( | |||||
(i + 1) % pooling_len == 0 and i != model_len - 1): | |||||
output_node_id = graph.add_layer( | |||||
self.pooling(), output_node_id) | |||||
output_node_id = graph.add_layer( | |||||
self.global_avg_pooling(), output_node_id) | |||||
output_node_id = graph.add_layer( | |||||
self.dropout(Constant.CONV_DROPOUT_RATE), output_node_id | |||||
) | |||||
output_node_id = graph.add_layer( | |||||
StubDense(graph.node_list[output_node_id].shape[0], model_width), | |||||
output_node_id, | |||||
) | |||||
output_node_id = graph.add_layer(StubReLU(), output_node_id) | |||||
graph.add_layer( | |||||
StubDense( | |||||
model_width, | |||||
self.n_output_node), | |||||
output_node_id) | |||||
return graph | |||||
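# Example usage (a minimal sketch; the input shape is channels-last, e.g. CIFAR-10):
#
#   generator = CnnGenerator(n_output_node=10, input_shape=(32, 32, 3))
#   graph = generator.generate(model_len=3, model_width=64)
#   model = graph.produce_torch_model()  # as done by the trainer later in this package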
class ResNetGenerator(NetworkGenerator): | |||||
def __init__(self, n_output_node, input_shape): | |||||
super(ResNetGenerator, self).__init__(n_output_node, input_shape) | |||||
# self.layers = [2, 2, 2, 2] | |||||
self.in_planes = 64 | |||||
self.block_expansion = 1 | |||||
self.n_dim = len(self.input_shape) - 1 | |||||
if len(self.input_shape) > 4: | |||||
raise ValueError('The input dimension is too high.') | |||||
elif len(self.input_shape) < 2: | |||||
raise ValueError('The input dimension is too low.') | |||||
self.conv = get_conv_class(self.n_dim) | |||||
self.dropout = get_dropout_class(self.n_dim) | |||||
self.global_avg_pooling = get_global_avg_pooling_class(self.n_dim) | |||||
self.adaptive_avg_pooling = get_global_avg_pooling_class(self.n_dim) | |||||
self.batch_norm = get_batch_norm_class(self.n_dim) | |||||
def generate(self, model_len=None, model_width=None): | |||||
if model_width is None: | |||||
model_width = Constant.MODEL_WIDTH | |||||
graph = Graph(self.input_shape, False) | |||||
temp_input_channel = self.input_shape[-1] | |||||
output_node_id = 0 | |||||
# output_node_id = graph.add_layer(StubReLU(), output_node_id) | |||||
output_node_id = graph.add_layer(self.conv(temp_input_channel, model_width, kernel_size=3), output_node_id) | |||||
output_node_id = graph.add_layer(self.batch_norm(model_width), output_node_id) | |||||
# output_node_id = graph.add_layer(self.pooling(kernel_size=3, stride=2, padding=1), output_node_id) | |||||
output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 1) | |||||
model_width *= 2 | |||||
output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2) | |||||
model_width *= 2 | |||||
output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2) | |||||
model_width *= 2 | |||||
output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2) | |||||
output_node_id = graph.add_layer(self.global_avg_pooling(), output_node_id) | |||||
graph.add_layer(StubDense(model_width * self.block_expansion, self.n_output_node), output_node_id) | |||||
return graph | |||||
def _make_layer(self, graph, planes, blocks, node_id, stride): | |||||
strides = [stride] + [1] * (blocks - 1) | |||||
out = node_id | |||||
for current_stride in strides: | |||||
out = self._make_block(graph, self.in_planes, planes, out, current_stride) | |||||
self.in_planes = planes * self.block_expansion | |||||
return out | |||||
def _make_block(self, graph, in_planes, planes, node_id, stride=1): | |||||
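# Pre-activation residual block: a shared BN + ReLU feeds both the 3x3-3x3 main path
# and a 1x1 projection shortcut (with matching stride); StubAdd joins the two paths.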
out = graph.add_layer(self.batch_norm(in_planes), node_id) | |||||
out = graph.add_layer(StubReLU(), out) | |||||
residual_node_id = out | |||||
out = graph.add_layer(self.conv(in_planes, planes, kernel_size=3, stride=stride), out) | |||||
out = graph.add_layer(self.batch_norm(planes), out) | |||||
out = graph.add_layer(StubReLU(), out) | |||||
out = graph.add_layer(self.conv(planes, planes, kernel_size=3), out) | |||||
residual_node_id = graph.add_layer(StubReLU(), residual_node_id) | |||||
residual_node_id = graph.add_layer(self.conv(in_planes, | |||||
planes * self.block_expansion, | |||||
kernel_size=1, | |||||
stride=stride), residual_node_id) | |||||
out = graph.add_layer(StubAdd(), (out, residual_node_id)) | |||||
return out | |||||
class MlpGenerator(NetworkGenerator): | |||||
"""A class to generate Multi-Layer Perceptron. | |||||
""" | |||||
def __init__(self, n_output_node, input_shape): | |||||
"""Initialize the instance. | |||||
Args: | |||||
n_output_node: An integer. Number of output nodes in the network. | |||||
input_shape: A tuple. Input shape of the network. For a 1-D input, keep the trailing
comma in the tuple, e.g. `(784,)`.
""" | |||||
super(MlpGenerator, self).__init__(n_output_node, input_shape) | |||||
if len(self.input_shape) > 1: | |||||
raise ValueError("The input dimension is too high.") | |||||
def generate(self, model_len=None, model_width=None): | |||||
"""Generates a Multi-Layer Perceptron. | |||||
Args: | |||||
model_len: An integer. Number of hidden layers.
model_width: An integer or a list of integers of length `model_len`. If it is a list, it gives the
number of nodes in each hidden layer; if it is an integer, every hidden layer has this many nodes.
Returns: | |||||
An instance of the class Graph. Represents the neural architecture graph of the generated model. | |||||
""" | |||||
if model_len is None: | |||||
model_len = Constant.MODEL_LEN | |||||
if model_width is None: | |||||
model_width = Constant.MODEL_WIDTH | |||||
if isinstance(model_width, list) and len(model_width) != model_len:
raise ValueError( | |||||
"The length of 'model_width' does not match 'model_len'") | |||||
elif isinstance(model_width, int): | |||||
model_width = [model_width] * model_len | |||||
graph = Graph(self.input_shape, False) | |||||
output_node_id = 0 | |||||
n_nodes_prev_layer = self.input_shape[0] | |||||
for width in model_width: | |||||
output_node_id = graph.add_layer( | |||||
StubDense(n_nodes_prev_layer, width), output_node_id | |||||
) | |||||
output_node_id = graph.add_layer( | |||||
StubDropout1d(Constant.MLP_DROPOUT_RATE), output_node_id | |||||
) | |||||
output_node_id = graph.add_layer(StubReLU(), output_node_id) | |||||
n_nodes_prev_layer = width | |||||
graph.add_layer( | |||||
StubDense( | |||||
n_nodes_prev_layer, | |||||
self.n_output_node), | |||||
output_node_id) | |||||
return graph |
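# Example usage (a minimal sketch; note the trailing comma for the 1-D input shape):
#
#   generator = MlpGenerator(n_output_node=10, input_shape=(784,))
#   graph = generator.generate(model_len=3, model_width=[128, 64, 32])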
@@ -0,0 +1,64 @@ | |||||
import numpy as np | |||||
import torch | |||||
import torchvision.transforms as transforms | |||||
from utils import Constant | |||||
class Cutout: | |||||
"""Randomly mask out one or more patches from an image. | |||||
Args: | |||||
n_holes (int): Number of patches to cut out of each image. | |||||
length (int): The length (in pixels) of each square patch. | |||||
""" | |||||
def __init__(self, n_holes, length): | |||||
self.n_holes = n_holes | |||||
self.length = length | |||||
def __call__(self, img): | |||||
""" | |||||
Args: | |||||
img (Tensor): Tensor image of size (C, H, W). | |||||
Returns: | |||||
Tensor: Image with n_holes of dimension length x length cut out of it. | |||||
""" | |||||
h, w = img.size(1), img.size(2) | |||||
mask = np.ones((h, w), np.float32) | |||||
for _ in range(self.n_holes): | |||||
y = np.random.randint(h) | |||||
x = np.random.randint(w) | |||||
y1 = np.clip(y - self.length // 2, 0, h) | |||||
y2 = np.clip(y + self.length // 2, 0, h) | |||||
x1 = np.clip(x - self.length // 2, 0, w) | |||||
x2 = np.clip(x + self.length // 2, 0, w) | |||||
mask[y1:y2, x1:x2] = 0.0 | |||||
mask = torch.from_numpy(mask) | |||||
mask = mask.expand_as(img) | |||||
img *= mask | |||||
return img | |||||
def data_transforms_cifar10(): | |||||
""" data_transforms for cifar10 dataset | |||||
""" | |||||
cifar_mean = [0.49139968, 0.48215827, 0.44653124] | |||||
cifar_std = [0.24703233, 0.24348505, 0.26158768] | |||||
train_transform = transforms.Compose( | |||||
[ | |||||
transforms.RandomCrop(32, padding=4), | |||||
transforms.RandomHorizontalFlip(), | |||||
transforms.ToTensor(), | |||||
transforms.Normalize(cifar_mean, cifar_std), | |||||
Cutout(n_holes=Constant.CUTOUT_HOLES, | |||||
length=int(32 * Constant.CUTOUT_RATIO)) | |||||
] | |||||
) | |||||
valid_transform = transforms.Compose( | |||||
[transforms.ToTensor(), transforms.Normalize(cifar_mean, cifar_std)] | |||||
) | |||||
return train_transform, valid_transform |
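# Example usage (a minimal sketch; torchvision downloads CIFAR-10 if it is absent):
#
#   train_transform, valid_transform = data_transforms_cifar10()
#   trainset = torchvision.datasets.CIFAR10(
#       root="./data", train=True, download=True, transform=train_transform)
#
# Note that Cutout runs after Normalize, so the masked pixels are zeros in the
# normalized space rather than in raw pixel space.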
@@ -0,0 +1,57 @@ | |||||
import sys | |||||
sys.path.append('../..')
import os | |||||
import logging | |||||
from pytorch.network_morphism.network_morphism_trainer import NetworkMorphismTrainer | |||||
import argparse | |||||
from pytorch.utils import init_logger, mkdirs | |||||
import json | |||||
logger = logging.getLogger(__name__) | |||||
class Retrain: | |||||
def __init__(self, args): | |||||
self.args = args | |||||
def run(self): | |||||
logger.info("Retraining the best model.") | |||||
with open(self.args.best_selected_space_path, 'r') as f:
json_out = json.load(f) | |||||
json_out = json.dumps(json_out) | |||||
trainer = NetworkMorphismTrainer(json_out, self.args) | |||||
trainer.retrain() | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser("network_morphism_retrain") | |||||
parser.add_argument("--trial_id", type=int, default=0, help="Trial id") | |||||
parser.add_argument("--log_path", type=str, | |||||
default='./log', help="log for info") | |||||
parser.add_argument( | |||||
"--experiment_dir", type=str, default='./TADL', help="experiment level path" | |||||
) | |||||
parser.add_argument( | |||||
"--best_selected_space_path", type=str, default='./best_selected_space.json', help="Path to best selected space" | |||||
) | |||||
parser.add_argument( | |||||
"--result_path", type=str, default='./result.json', help="Path to result" | |||||
) | |||||
parser.add_argument( | |||||
"--best_checkpoint_dir", type=str, default='./', help="Path to checkpoint saved" | |||||
) | |||||
parser.add_argument( | |||||
"--data_dir", type=str, default='../data/', help="Path to dataset" | |||||
) | |||||
parser.add_argument("--batch_size", type=int, | |||||
default=128, help="batch size") | |||||
parser.add_argument("--opt", type=str, default="SGD", help="optimizer") | |||||
parser.add_argument("--epochs", type=int, default=200, help="epoch limit") | |||||
parser.add_argument( | |||||
"--lr", type=float, default=0.001, help="learning rate" | |||||
) | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir) | |||||
init_logger(args.log_path) | |||||
retrain = Retrain(args) | |||||
retrain.run() |
@@ -0,0 +1,22 @@ | |||||
import sys | |||||
sys.path.append('../..')
from argparse import ArgumentParser | |||||
from pytorch.selector import Selector | |||||
class NetworkMorphismSelector(Selector): | |||||
def __init__(self, single_candidate=True): | |||||
super().__init__(single_candidate) | |||||
def fit(self): | |||||
pass | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("NetworkMorphism select") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
args = parser.parse_args() | |||||
select = NetworkMorphismSelector(True) | |||||
select.fit() |
@@ -0,0 +1,87 @@ | |||||
import sys | |||||
sys.path.append('../..')
import os | |||||
import logging | |||||
from pytorch.network_morphism.network_morphism_trainer import NetworkMorphismTrainer | |||||
from pytorch.network_morphism.algorithm.networkmorphism_searcher import NetworkMorphismSearcher | |||||
import argparse | |||||
import pickle | |||||
from pytorch.utils import init_logger, mkdirs | |||||
logger = logging.getLogger(__name__) | |||||
def create_dir(path): | |||||
if os.path.exists(path): | |||||
# shutil.rmtree(path) | |||||
return path | |||||
os.makedirs(path) | |||||
return path | |||||
class Train: | |||||
def __init__(self, args): | |||||
self.id = args.trial_id | |||||
self.trial_dir = os.path.join( | |||||
args.experiment_dir, 'train', str(args.trial_id)) | |||||
self.searcher_dir = os.path.join( | |||||
args.experiment_dir, '{}.pkl'.format(NetworkMorphismSearcher.__name__)) | |||||
self.args = args | |||||
self.searcher = None | |||||
# first trial | |||||
if not os.path.exists(self.searcher_dir): | |||||
self.searcher = NetworkMorphismSearcher(os.path.join( | |||||
args.experiment_dir, 'train'), args.best_selected_space_path) | |||||
else: | |||||
# load from previous round | |||||
with open(self.searcher_dir, 'rb') as f: | |||||
self.searcher = pickle.load(f) | |||||
def run_trial_job(self): | |||||
logger.info('trial {}: searching for the next model'.format(self.id))
model = self.searcher.search(self.id, self.args)
trainer = NetworkMorphismTrainer(model, self.args) | |||||
logger.info('trial {}: running the training script'.format(self.id))
metric = trainer.train()
if metric is not None:
logger.info('trial {}: received trial result'.format(self.id))
self.searcher.update_searcher(self.id, metric)
with open(self.searcher_dir, 'wb') as f: | |||||
pickle.dump(self.searcher, f) | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser("network_morphism") | |||||
parser.add_argument("--trial_id", type=int, default=0, help="Trial id") | |||||
parser.add_argument( | |||||
"--data_dir", type=str, default='../data/', help="Path to dataset" | |||||
) | |||||
parser.add_argument( | |||||
"--log_path", type=str, default='./log', help="Path to log file" | |||||
) | |||||
parser.add_argument( | |||||
"--experiment_dir", type=str, default='./TADL', help="experiment level path" | |||||
) | |||||
parser.add_argument( | |||||
"--result_path", type=str, default='./result.json', help="trial level path to result" | |||||
) | |||||
parser.add_argument( | |||||
"--search_space_path", type=str, default='./search_space.json', help="experiment level path to search space" | |||||
) | |||||
parser.add_argument( | |||||
"--best_selected_space_path", type=str, default='./best_selected_space.json', help="experiment level path to best selected space" | |||||
) | |||||
parser.add_argument("--batch_size", type=int, | |||||
default=128, help="batch size") | |||||
parser.add_argument("--opt", type=str, default="SGD", help="optimizer") | |||||
parser.add_argument("--epochs", type=int, default=2, help="epoch limit") | |||||
parser.add_argument( | |||||
"--lr", type=float, default=0.001, help="learning rate" | |||||
) | |||||
args = parser.parse_args() | |||||
mkdirs(args.experiment_dir, args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path) | |||||
create_dir(os.path.join(args.experiment_dir, 'train', str(args.trial_id)))
init_logger(args.log_path) | |||||
train = Train(args) | |||||
train.run_trial_job() |
@@ -0,0 +1,175 @@ | |||||
import logging | |||||
from algorithm.graph import json_to_graph | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.optim as optim | |||||
import torch.utils.data as data | |||||
import torchvision | |||||
import re | |||||
import datasets | |||||
from utils import Constant, EarlyStop, save_json_result | |||||
from pytorch.utils import save_best_checkpoint | |||||
# pylint: disable=W0603 | |||||
# module-level logger
logger = logging.getLogger(__name__) | |||||
class NetworkMorphismTrainer: | |||||
def __init__(self, model_json, args): | |||||
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |||||
self.batch_size = args.batch_size | |||||
self.epochs = args.epochs | |||||
self.lr = args.lr | |||||
self.optimizer_name = args.opt | |||||
self.data_dir = args.data_dir | |||||
self.trial_id = args.trial_id | |||||
self.args = args | |||||
# Loading Data | |||||
logger.info("Preparing data..") | |||||
transform_train, transform_test = datasets.data_transforms_cifar10() | |||||
trainset = torchvision.datasets.CIFAR10( | |||||
root=self.data_dir, train=True, download=True, transform=transform_train | |||||
) | |||||
self.trainloader = data.DataLoader( | |||||
trainset, batch_size=self.batch_size, shuffle=True, num_workers=1 | |||||
) | |||||
testset = torchvision.datasets.CIFAR10( | |||||
root=self.data_dir, train=False, download=True, transform=transform_test | |||||
) | |||||
self.testloader = data.DataLoader( | |||||
testset, batch_size=self.batch_size, shuffle=False, num_workers=1 | |||||
) | |||||
# Model | |||||
logger.info("Building model..") | |||||
# build model from json representation | |||||
self.graph = json_to_graph(model_json) | |||||
self.net = self.graph.produce_torch_model() | |||||
if self.device == "cuda" and torch.cuda.device_count() > 1: | |||||
self.net = nn.DataParallel(self.net) | |||||
self.net.to(self.device) | |||||
self.criterion = nn.CrossEntropyLoss() | |||||
if self.optimizer_name == "SGD": | |||||
self.optimizer = optim.SGD( | |||||
self.net.parameters(), lr=self.lr, momentum=0.9, weight_decay=3e-4 | |||||
) | |||||
if self.optimizer_name == "Adadelta": | |||||
self.optimizer = optim.Adadelta(self.net.parameters(), lr=self.lr) | |||||
if self.optimizer_name == "Adagrad": | |||||
self.optimizer = optim.Adagrad(self.net.parameters(), lr=self.lr) | |||||
if self.optimizer_name == "Adam": | |||||
self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr) | |||||
if self.optimizer_name == "Adamax": | |||||
self.optimizer = optim.Adamax(self.net.parameters(), lr=self.lr) | |||||
if self.optimizer_name == "RMSprop": | |||||
self.optimizer = optim.RMSprop(self.net.parameters(), lr=self.lr) | |||||
self.scheduler = optim.lr_scheduler.CosineAnnealingLR( | |||||
self.optimizer, self.epochs) | |||||
def train_one_epoch(self): | |||||
""" | |||||
Train the model for one epoch on the training set.
""" | |||||
self.net.train() | |||||
for batch_idx, (inputs, targets) in enumerate(self.trainloader): | |||||
inputs, targets = inputs.to(self.device), targets.to(self.device) | |||||
self.optimizer.zero_grad() | |||||
outputs = self.net(inputs) | |||||
loss = self.criterion(outputs, targets) | |||||
loss.backward() | |||||
self.optimizer.step() | |||||
def validate_one_epoch(self, epoch): | |||||
""" eval model on each epoch in testset | |||||
""" | |||||
self.net.eval() | |||||
test_loss = 0 | |||||
correct = 0 | |||||
total = 0 | |||||
with torch.no_grad(): | |||||
for batch_idx, (inputs, targets) in enumerate(self.testloader): | |||||
inputs, targets = inputs.to( | |||||
self.device), targets.to(self.device) | |||||
outputs = self.net(inputs) | |||||
loss = self.criterion(outputs, targets) | |||||
test_loss += loss.item() | |||||
_, predicted = outputs.max(1) | |||||
total += targets.size(0) | |||||
correct += predicted.eq(targets).sum().item() | |||||
acc = correct / total | |||||
logger.info("Epoch: %d, accuracy: %.3f", epoch, acc) | |||||
result = {"type": "Accuracy", "result": { | |||||
"sequence": epoch, "category": "epoch", "value": acc}} | |||||
save_json_result(self.args.result_path, result) | |||||
return test_loss, acc | |||||
def train(self): | |||||
try: | |||||
max_no_improvement_num = Constant.MAX_NO_IMPROVEMENT_NUM | |||||
early_stop = EarlyStop(max_no_improvement_num) | |||||
early_stop.on_train_begin() | |||||
test_metric_value_list = [] | |||||
for ep in range(self.epochs): | |||||
self.train_one_epoch() | |||||
test_loss, test_acc = self.validate_one_epoch(ep) | |||||
self.scheduler.step() | |||||
test_metric_value_list.append(test_acc) | |||||
decreasing = early_stop.on_epoch_end(test_loss) | |||||
if not decreasing: | |||||
break | |||||
last_num = min(max_no_improvement_num, self.epochs) | |||||
estimated_performance = sum( | |||||
test_metric_value_list[-last_num:]) / last_num | |||||
logger.info("final accuracy: %.3f", estimated_performance) | |||||
except RuntimeError as e: | |||||
if not re.search('out of memory', str(e)):
raise e
logger.warning('Current model is too large for GPU memory. Discontinuing training to search for other models.')
# Shrink the global size cap so the searcher avoids proposing models this large again.
Constant.MAX_MODEL_SIZE = self.graph.size() - 1
return None | |||||
except Exception as e: | |||||
logger.exception(e) | |||||
raise | |||||
return estimated_performance | |||||
def retrain(self): | |||||
logger.info("here") | |||||
try: | |||||
best_acc = 0.0 | |||||
for ep in range(self.epochs): | |||||
logger.info("Epoch %d", ep)
self.train_one_epoch() | |||||
_, test_acc = self.validate_one_epoch(ep) | |||||
self.scheduler.step() | |||||
if test_acc > best_acc: | |||||
best_acc = test_acc | |||||
save_best_checkpoint(self.args.best_checkpoint_dir, | |||||
self.net, self.optimizer, self.epochs) | |||||
logger.info("final accuracy: %.3f", best_acc) | |||||
except Exception as exception: | |||||
logger.exception(exception) | |||||
raise |
@@ -0,0 +1,102 @@ | |||||
from enum import Enum | |||||
import json | |||||
class Constant: | |||||
# Data | |||||
CUTOUT_HOLES = 1 | |||||
CUTOUT_RATIO = 0.5 | |||||
# Searcher | |||||
MAX_MODEL_NUM = 1000 | |||||
MAX_LAYERS = 200 | |||||
N_NEIGHBOURS = 8 | |||||
MAX_MODEL_SIZE = (1 << 25) | |||||
MAX_LAYER_WIDTH = 4096 | |||||
KERNEL_LAMBDA = 1.0 | |||||
BETA = 2.576 | |||||
T_MIN = 0.0001 | |||||
MLP_MODEL_LEN = 3 | |||||
MLP_MODEL_WIDTH = 5 | |||||
MODEL_LEN = 3 | |||||
MODEL_WIDTH = 64 | |||||
POOLING_KERNEL_SIZE = 2 | |||||
DENSE_DROPOUT_RATE = 0.5 | |||||
CONV_DROPOUT_RATE = 0.25 | |||||
MLP_DROPOUT_RATE = 0.25 | |||||
CONV_BLOCK_DISTANCE = 2 | |||||
# trainer | |||||
MAX_NO_IMPROVEMENT_NUM = 5 | |||||
MIN_LOSS_DEC = 1e-4 | |||||
class OptimizeMode(Enum): | |||||
"""Optimize Mode class | |||||
if OptimizeMode is 'minimize', it means the tuner need to minimize the reward | |||||
that received from Trial. | |||||
if OptimizeMode is 'maximize', it means the tuner need to maximize the reward | |||||
that received from Trial. | |||||
""" | |||||
Minimize = 'minimize' | |||||
Maximize = 'maximize' | |||||
class EarlyStop: | |||||
"""A class check for early stop condition. | |||||
Attributes: | |||||
training_losses: Record all the training loss. | |||||
minimum_loss: The minimum loss we achieve so far. Used to compared to determine no improvement condition. | |||||
no_improvement_count: Current no improvement count. | |||||
_max_no_improvement_num: The maximum number specified. | |||||
_done: Whether condition met. | |||||
_min_loss_dec: A threshold for loss improvement. | |||||
""" | |||||
def __init__(self, max_no_improvement_num=None, min_loss_dec=None): | |||||
self.training_losses = [] | |||||
self.minimum_loss = None | |||||
self.no_improvement_count = 0 | |||||
self._max_no_improvement_num = max_no_improvement_num if max_no_improvement_num is not None \ | |||||
else Constant.MAX_NO_IMPROVEMENT_NUM | |||||
self._done = False | |||||
self._min_loss_dec = min_loss_dec if min_loss_dec is not None else Constant.MIN_LOSS_DEC | |||||
def on_train_begin(self): | |||||
"""Initiate the early stop condition. | |||||
Call on every time the training iteration begins. | |||||
""" | |||||
self.training_losses = [] | |||||
self.no_improvement_count = 0 | |||||
self._done = False | |||||
self.minimum_loss = float('inf') | |||||
def on_epoch_end(self, loss): | |||||
"""Check the early stop condition. | |||||
Call on every time the training iteration end. | |||||
Args: | |||||
loss: The loss function achieved by the epoch. | |||||
Returns: | |||||
True if condition met, otherwise False. | |||||
""" | |||||
self.training_losses.append(loss) | |||||
if self._done and loss > (self.minimum_loss - self._min_loss_dec): | |||||
return False | |||||
if loss > (self.minimum_loss - self._min_loss_dec): | |||||
self.no_improvement_count += 1 | |||||
else: | |||||
self.no_improvement_count = 0 | |||||
self.minimum_loss = loss | |||||
if self.no_improvement_count > self._max_no_improvement_num: | |||||
self._done = True | |||||
return True | |||||
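# Example usage (a minimal sketch, mirroring how the trainer drives EarlyStop):
#
#   early_stop = EarlyStop()
#   early_stop.on_train_begin()
#   for epoch in range(epochs):
#       loss = validate()  # hypothetical per-epoch validation loss
#       if not early_stop.on_epoch_end(loss):
#           break  # too many epochs without sufficient improvement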
def save_json_result(path, data): | |||||
with open(path, 'a') as f:
json.dump(data, f)
f.write('\n') |
@@ -0,0 +1 @@ | |||||
from pytorch.pcdarts.pcdartsmutator import PCdartsMutator |
@@ -0,0 +1,227 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.nn as nn | |||||
from pytorch import mutables | |||||
from pytorch.darts import ops | |||||
def random_channel_shuffle(x): | |||||
num_channels = x.data.size()[1] | |||||
indices = torch.randperm(num_channels) | |||||
x = x[:, indices] | |||||
return x | |||||
def channel_shuffle(x, groups): | |||||
batchsize, num_channels, height, width = x.data.size() | |||||
channels_per_group = num_channels // groups | |||||
# reshape | |||||
x = x.view(batchsize, groups, | |||||
channels_per_group, height, width) | |||||
x = torch.transpose(x, 1, 2).contiguous() | |||||
# flatten | |||||
x = x.view(batchsize, -1, height, width) | |||||
return x | |||||
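# For example, with groups=4 a tensor whose 8 channels are ordered [0..7] comes out
# ordered [0, 2, 4, 6, 1, 3, 5, 7]: channels from different groups are interleaved,
# so the op-processed split and the bypassed split mix between iterations.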
class AuxiliaryHead(nn.Module): | |||||
""" Auxiliary head in 2/3 place of network to let the gradient flow well """ | |||||
def __init__(self, input_size, C, n_classes): | |||||
""" assuming input size 7x7 or 8x8 """ | |||||
assert input_size in [7, 8] | |||||
super().__init__() | |||||
self.net = nn.Sequential( | |||||
nn.ReLU(inplace=True), | |||||
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out | |||||
nn.Conv2d(C, 128, kernel_size=1, bias=False), | |||||
nn.BatchNorm2d(128), | |||||
nn.ReLU(inplace=True), | |||||
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out | |||||
nn.BatchNorm2d(768), | |||||
nn.ReLU(inplace=True) | |||||
) | |||||
self.linear = nn.Linear(768, n_classes) | |||||
def forward(self, x): | |||||
out = self.net(x) | |||||
out = out.view(out.size(0), -1) # flatten | |||||
logits = self.linear(out) | |||||
return logits | |||||
class Node(nn.Module): | |||||
def __init__(self, node_id, num_prev_nodes, channels, k, num_downsample_connect, search): | |||||
super().__init__() | |||||
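# PC-DARTS partial channel connection: in search, only channels // k of each input
# goes through the candidate ops; the remaining channels bypass them (see forward).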
if search:
self.k = k
partial_channels = channels // k
else:
partial_channels = channels
self.search = search | |||||
self.ops = nn.ModuleList() | |||||
choice_keys = [] | |||||
for i in range(num_prev_nodes): | |||||
stride = 2 if i < num_downsample_connect else 1 | |||||
choice_keys.append("{}_p{}".format(node_id, i)) | |||||
self.ops.append( | |||||
mutables.LayerChoice(OrderedDict([ | |||||
("maxpool", ops.PoolBN('max', partial_channles, 3, stride, 1, affine=False)), | |||||
("avgpool", ops.PoolBN('avg', partial_channles, 3, stride, 1, affine=False)), | |||||
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(partial_channles, partial_channles, affine=False)), | |||||
("sepconv3x3", ops.SepConv(partial_channles, partial_channles, 3, stride, 1, affine=False)), | |||||
("sepconv5x5", ops.SepConv(partial_channles, partial_channles, 5, stride, 2, affine=False)), | |||||
("dilconv3x3", ops.DilConv(partial_channles, partial_channles, 3, stride, 2, 2, affine=False)), | |||||
("dilconv5x5", ops.DilConv(partial_channles, partial_channles, 5, stride, 4, 2, affine=False)) | |||||
]), key=choice_keys[-1])) | |||||
self.drop_path = ops.DropPath() | |||||
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) | |||||
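# Bare max-pool used in forward to downsample the bypassed channels on reduction edges.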
self.pool = nn.MaxPool2d(2, 2)
def forward(self, prev_nodes): | |||||
assert len(self.ops) == len(prev_nodes), "len(self.ops) != len(prev_nodes) in Node" | |||||
# for each candidate predecessor of each intermediate node
if self.search: | |||||
# in search | |||||
results = [] | |||||
for op, x in zip(self.ops, prev_nodes): | |||||
# partial channel connection: split off channels // k to go through the chosen op
channels = x.shape[1]
# channel proportion k (default 4)
temp0 = x[:, :channels // self.k, :, :]
temp1 = x[:, channels // self.k:, :, :]
out = op(temp0)
# normal edge: spatial size unchanged, bypassed channels concatenated as-is
if out.shape[2] == x.shape[2]:
result = torch.cat([out, temp1], dim=1)
# reduction edge: downsample the bypassed channels to match
else:
result = torch.cat([out, self.pool(temp1)], dim=1)
results.append(channel_shuffle(result, self.k))
# Alternative left from development: randomly permute channels with
# random_channel_shuffle(x) before splitting, instead of the fixed grouped shuffle.
else: | |||||
# in retrain, no channel shuffle | |||||
results = [op(node) for op, node in zip(self.ops, prev_nodes)] | |||||
output = [self.drop_path(r) if r is not None else None for r in results]
return self.input_switch(output) | |||||
class Cell(nn.Module): | |||||
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, k, search): | |||||
super().__init__() | |||||
self.reduction = reduction | |||||
self.n_nodes = n_nodes | |||||
# If previous cell is reduction cell, current input size does not match with | |||||
# output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. | |||||
if reduction_p: | |||||
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) | |||||
else: | |||||
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) | |||||
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) | |||||
# generate dag | |||||
self.mutable_ops = nn.ModuleList() | |||||
for depth in range(2, self.n_nodes + 2): | |||||
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth), depth, channels, k, 2 if reduction else 0, search)) | |||||
def forward(self, s0, s1): | |||||
# s0, s1 are the outputs of previous previous cell and previous cell, respectively. | |||||
tensors = [self.preproc0(s0), self.preproc1(s1)] | |||||
for node in self.mutable_ops: | |||||
cur_tensor = node(tensors) | |||||
tensors.append(cur_tensor) | |||||
output = torch.cat(tensors[2:], dim=1) | |||||
return output | |||||
class CNN(nn.Module): | |||||
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, k=4, n_nodes=4, stem_multiplier=3, auxiliary=False, search=True): | |||||
super().__init__() | |||||
self.in_channels = in_channels | |||||
self.channels = channels | |||||
self.n_classes = n_classes | |||||
self.n_layers = n_layers | |||||
self.n_nodes = n_nodes | |||||
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1 | |||||
c_cur = stem_multiplier * self.channels | |||||
self.stem = nn.Sequential( | |||||
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), | |||||
nn.BatchNorm2d(c_cur) | |||||
) | |||||
# for the first cell, stem is used for both s0 and s1 | |||||
# [!] channels_pp and channels_p are output channel sizes, but c_cur is the input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels | |||||
self.cells = nn.ModuleList() | |||||
reduction_p, reduction = False, False | |||||
for i in range(n_layers): | |||||
reduction_p, reduction = reduction, False | |||||
# Reduce featuremap size and double channels in 1/3 and 2/3 layer. | |||||
if i in [n_layers // 3, 2 * n_layers // 3]: | |||||
c_cur *= 2 | |||||
reduction = True | |||||
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, k, search) | |||||
self.cells.append(cell) | |||||
c_cur_out = c_cur * n_nodes | |||||
channels_pp, channels_p = channels_p, c_cur_out | |||||
if i == self.aux_pos: | |||||
self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes) | |||||
self.gap = nn.AdaptiveAvgPool2d(1) | |||||
self.linear = nn.Linear(channels_p, n_classes) | |||||
def forward(self, x): | |||||
s0 = s1 = self.stem(x) | |||||
aux_logits = None | |||||
for i, cell in enumerate(self.cells): | |||||
s0, s1 = s1, cell(s0, s1) | |||||
if i == self.aux_pos and self.training: | |||||
aux_logits = self.aux_head(s1) | |||||
out = self.gap(s1) | |||||
out = out.view(out.size(0), -1) # flatten | |||||
logits = self.linear(out) | |||||
if aux_logits is not None: | |||||
return logits, aux_logits | |||||
return logits | |||||
def drop_path_prob(self, p): | |||||
for module in self.modules(): | |||||
if isinstance(module, ops.DropPath): | |||||
module.p = p | |||||
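# NOTE: _loss assumes the trainer attaches a `_criterion` attribute to this model
# before calling it; nothing in this file defines it.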
def _loss(self, input, target): | |||||
logits = self(input) | |||||
return self._criterion(logits, target) |
@@ -0,0 +1,204 @@ | |||||
import sys | |||||
sys.path.append('../..')
import os | |||||
import logging | |||||
import time | |||||
from argparse import ArgumentParser | |||||
import torch | |||||
import torch.nn as nn | |||||
import numpy as np | |||||
# from torch.utils.tensorboard import SummaryWriter | |||||
import torch.backends.cudnn as cudnn | |||||
from model import CNN | |||||
from pytorch.fixed import apply_fixed_architecture | |||||
from pytorch.utils import set_seed, mkdirs, init_logger, save_best_checkpoint, AverageMeter | |||||
from pytorch.darts import utils | |||||
from pytorch.darts import datasets | |||||
from pytorch.retrainer import Retrainer | |||||
logger = logging.getLogger(__name__) | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
# writer = SummaryWriter() | |||||
class PCdartsRetrainer(Retrainer): | |||||
def __init__(self, aux_weight, grad_clip, epochs, log_frequency): | |||||
self.aux_weight = aux_weight | |||||
self.grad_clip = grad_clip | |||||
self.epochs = epochs | |||||
self.log_frequency = log_frequency | |||||
def train(self, train_loader, model, optimizer, criterion, epoch): | |||||
top1 = AverageMeter("top1") | |||||
top5 = AverageMeter("top5") | |||||
losses = AverageMeter("losses") | |||||
cur_step = epoch * len(train_loader) | |||||
cur_lr = optimizer.param_groups[0]["lr"] | |||||
logger.info("Epoch %d LR %.6f", epoch, cur_lr) | |||||
# writer.add_scalar("lr", cur_lr, global_step=cur_step) | |||||
model.train() | |||||
for step, (x, y) in enumerate(train_loader): | |||||
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) | |||||
bs = x.size(0) | |||||
optimizer.zero_grad() | |||||
logits, aux_logits = model(x) | |||||
loss = criterion(logits, y) | |||||
if self.aux_weight > 0.: | |||||
loss += self.aux_weight * criterion(aux_logits, y) | |||||
loss.backward() | |||||
# gradient clipping | |||||
nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip) | |||||
optimizer.step() | |||||
accuracy = utils.accuracy(logits, y, topk=(1, 5)) | |||||
losses.update(loss.item(), bs) | |||||
top1.update(accuracy["acc1"], bs) | |||||
top5.update(accuracy["acc5"], bs) | |||||
# writer.add_scalar("loss/train", loss.item(), global_step=cur_step) | |||||
# writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step) | |||||
# writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step) | |||||
if step % self.log_frequency == 0 or step == len(train_loader) - 1: | |||||
logger.info( | |||||
"Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " | |||||
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( | |||||
epoch + 1, self.epochs, step, len(train_loader) - 1, losses=losses, | |||||
top1=top1, top5=top5)) | |||||
cur_step += 1 | |||||
logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg)) | |||||
def validate(self, valid_loader, model, criterion, epoch, cur_step): | |||||
top1 = AverageMeter("top1") | |||||
top5 = AverageMeter("top5") | |||||
losses = AverageMeter("losses") | |||||
model.eval() | |||||
with torch.no_grad(): | |||||
for step, (X, y) in enumerate(valid_loader): | |||||
X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) | |||||
bs = X.size(0) | |||||
logits = model(X) | |||||
loss = criterion(logits, y) | |||||
accuracy = utils.accuracy(logits, y, topk=(1, 5)) | |||||
losses.update(loss.item(), bs) | |||||
top1.update(accuracy["acc1"], bs) | |||||
top5.update(accuracy["acc5"], bs) | |||||
if step % self.log_frequency == 0 or step == len(valid_loader) - 1: | |||||
logger.info( | |||||
"Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " | |||||
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( | |||||
epoch + 1, self.epochs, step, len(valid_loader) - 1, losses=losses, | |||||
top1=top1, top5=top5)) | |||||
# writer.add_scalar("loss/test", losses.avg, global_step=cur_step) | |||||
# writer.add_scalar("acc1/test", top1.avg, global_step=cur_step) | |||||
# writer.add_scalar("acc5/test", top5.avg, global_step=cur_step) | |||||
logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg)) | |||||
return top1.avg | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("PCDARTS retrain") | |||||
parser.add_argument("--data_dir", type=str, | |||||
default='./', help="Path to dataset")
parser.add_argument("--result_path", type=str, | |||||
default='./result.json', help="training result") | |||||
parser.add_argument("--log_path", type=str, | |||||
default='./log', help="log for info")
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
parser.add_argument("--best_checkpoint_dir", type=str, | |||||
default='', help="default name is best_checkpoint_epoch{}.pth") | |||||
parser.add_argument('--trial_id', type=int, default=0, metavar='N', | |||||
help='trial id, starting from 0')
parser.add_argument("--layers", default=20, type=int) | |||||
parser.add_argument("--lr", default=0.01, type=float) | |||||
parser.add_argument("--batch_size", default=96, type=int) | |||||
parser.add_argument("--log_frequency", default=10, type=int) | |||||
parser.add_argument("--epochs", default=600, type=int) | |||||
parser.add_argument("--aux_weight", default=0.4, type=float) | |||||
parser.add_argument("--drop_path_prob", default=0.2, type=float) | |||||
parser.add_argument("--workers", default=4, type=int) | |||||
parser.add_argument("--class_num", default=10, type=int, help="cifar10") | |||||
parser.add_argument("--channels", default=36, type=int) | |||||
parser.add_argument("--grad_clip", default=6., type=float) | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir) | |||||
init_logger(args.log_path) | |||||
logger.info(args) | |||||
set_seed(args.trial_id) | |||||
logger.info("loading data") | |||||
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir) | |||||
model = CNN(32, 3, args.channels, args.class_num, args.layers, auxiliary=True, search=False) | |||||
apply_fixed_architecture(model, args.best_selected_space_path) | |||||
criterion = nn.CrossEntropyLoss() | |||||
model.to(device) | |||||
criterion.to(device) | |||||
optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6) | |||||
train_loader = torch.utils.data.DataLoader(dataset_train, | |||||
batch_size=args.batch_size, | |||||
shuffle=True, | |||||
num_workers=args.workers, | |||||
pin_memory=True) | |||||
valid_loader = torch.utils.data.DataLoader(dataset_valid, | |||||
batch_size=args.batch_size, | |||||
shuffle=False, | |||||
num_workers=args.workers, | |||||
pin_memory=True) | |||||
retrainer = PCdartsRetrainer(aux_weight=args.aux_weight, | |||||
grad_clip=args.grad_clip, | |||||
epochs=args.epochs, | |||||
log_frequency = args.log_frequency) | |||||
# result = {"Accuracy": [], "Cost_time": ''} | |||||
best_top1 = 0. | |||||
start_time = time.time() | |||||
for epoch in range(args.epochs): | |||||
drop_prob = args.drop_path_prob * epoch / args.epochs | |||||
model.drop_path_prob(drop_prob) | |||||
# training | |||||
retrainer.train(train_loader, model, optimizer, criterion, epoch) | |||||
# validation | |||||
cur_step = (epoch + 1) * len(train_loader) | |||||
top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step) | |||||
# The backend filters this pattern from the terminal output, e.g. {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) + '\n') | |||||
# result["Accuracy"].append(top1) | |||||
best_top1 = max(best_top1, top1) | |||||
lr_scheduler.step() | |||||
logger.info("Final best Prec@1 = {:.4%}".format(best_top1)) | |||||
cost_time = time.time() - start_time | |||||
# The backend filters this pattern from the terminal output, e.g. {"type": "Cost_time", "result": {"value": "* s"}}
logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) | |||||
# result["Cost_time"] = str(cost_time) + ' s' | |||||
# dump_global_result(args.result_path, result) | |||||
save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch) | |||||
logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch)))) |
@@ -0,0 +1,21 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
from pytorch.selector import Selector | |||||
from argparse import ArgumentParser | |||||
class PCdartsSelector(Selector): | |||||
def __init__(self, single_candidate=True): | |||||
super().__init__(single_candidate) | |||||
def fit(self): | |||||
pass | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("DARTS select") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
args = parser.parse_args() | |||||
darts_selector = PCdartsSelector(True) | |||||
darts_selector.fit() |
@@ -0,0 +1,93 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import sys | |||||
sys.path.append('../..')
import time | |||||
from argparse import ArgumentParser | |||||
from model import CNN | |||||
import torch | |||||
import torch.nn as nn | |||||
from pytorch.callbacks import BestArchitectureCheckpoint, LRSchedulerCallback | |||||
from pytorch.pcdarts import PCdartsMutator | |||||
from pytorch.darts import DartsTrainer | |||||
from pytorch.darts.utils import accuracy | |||||
from pytorch.darts import datasets | |||||
import logging
from pytorch.utils import set_seed, mkdirs, init_logger
logger = logging.getLogger(__name__)
if __name__ == "__main__": | |||||
parser = ArgumentParser("PCDARTS train") | |||||
parser.add_argument("--data_dir", type=str, | |||||
default='../data/', help="Path to dataset")
parser.add_argument("--result_path", type=str, | |||||
default='./result.json', help="training result")
parser.add_argument("--log_path", type=str, | |||||
default='./log', help="log for info")
parser.add_argument("--search_space_path", type=str, | |||||
default='./search_space.json', help="search space of PCDARTS")
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
parser.add_argument('--trial_id', type=int, default=0, metavar='N', | |||||
help='trial id, starting from 0')
parser.add_argument('--model_lr', type=float, default=0.1, help='learning rate for training model weights') | |||||
parser.add_argument('--arch_lr', type=float, default=6e-4, help='learning rate for training architecture') | |||||
parser.add_argument("--nodes", default=4, type=int) | |||||
parser.add_argument("--layers", default=8, type=int) | |||||
parser.add_argument("--channels", default=16, type=int) | |||||
parser.add_argument("--batch_size", default=96, type=int) | |||||
parser.add_argument("--log_frequency", default=50, type=int) | |||||
parser.add_argument("--class_num", default=10, type=int, help="cifar10") | |||||
parser.add_argument("--epochs", default=5, type=int) | |||||
parser.add_argument("--pre_epochs", default=15, type=int, help='pre epochs to train weight only') | |||||
parser.add_argument("--k", default=4, type=int, help="channel portion of channel shuffle") | |||||
parser.add_argument("--unrolled", default=False, action="store_true") | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path) | |||||
init_logger(args.log_path, "info") | |||||
logger.info(args) | |||||
set_seed(args.trial_id) | |||||
logger.info("loading data") | |||||
dataset_train, dataset_valid = datasets.get_dataset("cifar10", root=args.data_dir) | |||||
model = CNN(32, 3, args.channels, args.class_num, args.layers, n_nodes=args.nodes, k=args.k) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optim = torch.optim.SGD(model.parameters(), args.model_lr, momentum=0.9, weight_decay=3.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) | |||||
logger.info("initializing trainer") | |||||
trainer = DartsTrainer(model, | |||||
loss=criterion, | |||||
metrics=lambda output, target: accuracy(output, target, topk=(1,)), | |||||
optimizer=optim, | |||||
num_epochs=args.epochs, | |||||
dataset_train=dataset_train, | |||||
dataset_valid=dataset_valid, | |||||
mutator=PCdartsMutator(model), | |||||
batch_size=args.batch_size, | |||||
log_frequency=args.log_frequency, | |||||
arch_lr=args.arch_lr, | |||||
unrolled=args.unrolled, | |||||
result_path=args.result_path, | |||||
num_pre_epochs=args.pre_epochs, | |||||
search_space_path=args.search_space_path, | |||||
callbacks=[LRSchedulerCallback(lr_scheduler), BestArchitectureCheckpoint(args.best_selected_space_path, args.epochs)])
logger.info("training") | |||||
t1 = time.time() | |||||
trainer.train() | |||||
cost_time = time.time() - t1 | |||||
# The backend filters this pattern from the terminal output, e.g. {"type": "Cost_time", "result": {"value": "* s"}}
logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) | |||||
# res_json["Cost_time"] = str(cost_time) + ' s' | |||||
# dump_global_result(args.result_path, res_json) |
@@ -0,0 +1,146 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
import logging | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from collections import OrderedDict | |||||
from pytorch.mutator import Mutator | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
_logger = logging.getLogger(__name__) | |||||
class PCdartsMutator(Mutator): | |||||
""" | |||||
Connects the model in a PC-DARTS (differentiable) way. | |||||
An extra op slot (a ``ZeroOp``, i.e. "none") is appended to every LayerChoice; when it
wins, there is no op on that edge and every element in the exported choice list is
``false`` (not chosen). In the search phase, the ops of each LayerChoice are mixed with
softmax-normalized weights, and every InputChoice is fully connected, with the partial
channels of all candidate predecessors combined as a weighted sum. On exporting, the
InputChoice chooses inputs based on keys in ``choose_from``: if a key is the key of a
LayerChoice, the top logit of that LayerChoice joins the competition against the other
logits; otherwise, the logit is assumed to be 0.
It's possible to cut branches by setting a particular position of ``choices`` to
``-inf``. After softmax, the weight there becomes 0; the framework ignores 0 weights
and does not connect them. Note that the gradient at the ``-inf`` location is 0, and
since manipulations with ``-inf`` produce ``nan``, the gradient update phase must be
handled carefully.
Attributes
----------
choices: ParameterDict
dict that maps keys of LayerChoices and InputChoices to architecture-weight tensors.
""" | |||||
def __init__(self, model): | |||||
super().__init__(model) | |||||
self.choices = nn.ParameterDict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
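# one extra logit for the implicit "none" op appended to every LayerChoice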
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1)) | |||||
if isinstance(mutable, InputChoice): | |||||
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.n_candidates)) | |||||
def device(self): | |||||
for v in self.choices.values(): | |||||
return v.device | |||||
def sample_search(self): | |||||
result = dict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
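# softmax over all ops including the trailing "none" logit, whose probability is
# then dropped, so the returned search-time weights need not sum to 1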
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1] | |||||
elif isinstance(mutable, InputChoice): | |||||
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1) | |||||
return result | |||||
def sample_final(self): | |||||
result = dict() | |||||
edges_max = dict() | |||||
choices = dict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
# multiply the normalized coefficients together to select top-1 op in each LayerChoice | |||||
predecessor_idx = int(mutable.key[-1]) | |||||
inputchoice_key = mutable.key[:-2] + "switch" | |||||
choices[mutable.key] = self.choices[mutable.key] * self.choices[inputchoice_key][predecessor_idx] | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
# select non-none top-1 op | |||||
max_val, index = torch.max(F.softmax(choices[mutable.key], dim=-1)[:-1], 0) | |||||
edges_max[mutable.key] = max_val | |||||
result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, InputChoice): | |||||
if mutable.n_chosen is not None: | |||||
weights = [] | |||||
for src_key in mutable.choose_from: | |||||
if src_key not in edges_max: | |||||
_logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key) | |||||
weights.append(edges_max.get(src_key, 0.)) | |||||
weights = torch.tensor(weights) # pylint: disable=not-callable | |||||
# select the n_chosen strongest predecessor edges
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen) | |||||
selected_multihot = [] | |||||
for i, src_key in enumerate(mutable.choose_from): | |||||
if i not in topk_edge_indices and src_key in result: | |||||
# If an edge is never selected, there is no need to calculate any op on this edge. | |||||
# This is to eliminate redundant calculation. | |||||
result[src_key] = torch.zeros_like(result[src_key]) | |||||
selected_multihot.append(i in topk_edge_indices) | |||||
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable | |||||
else: | |||||
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable | |||||
return result | |||||
def _generate_search_space(self): | |||||
""" | |||||
Generate search space from mutables. | |||||
Here is the search space format: | |||||
:: | |||||
{ key_name: {"_type": "layer_choice", | |||||
"_value": ["conv1", "conv2"]} } | |||||
{ key_name: {"_type": "input_choice", | |||||
"_value": {"candidates": ["in1", "in2"], | |||||
"n_chosen": 1}} } | |||||
Returns | |||||
------- | |||||
dict | |||||
the generated search space | |||||
""" | |||||
res = OrderedDict() | |||||
res["op_list"] = OrderedDict() | |||||
res["search_space"] = {"reduction_cell": OrderedDict(), "normal_cell": OrderedDict()} | |||||
keys = [] | |||||
for mutable in self.mutables: | |||||
# for now we only generate flattened search space | |||||
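# 36 keys = 2 cell types * (14 LayerChoices + 4 InputChoices) when n_nodes = 4,
# i.e. one normal cell and one reduction cell.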
if (len(res["search_space"]["reduction_cell"]) + len(res["search_space"]["normal_cell"])) >= 36: | |||||
break | |||||
if isinstance(mutable, LayerChoice): | |||||
key = mutable.key | |||||
if key not in keys: | |||||
val = mutable.names | |||||
if not res["op_list"]: | |||||
res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]} | |||||
node_type = "normal_cell" if "normal" in key else "reduction_cell" | |||||
res["search_space"][node_type][key] = "op_list" | |||||
keys.append(key) | |||||
elif isinstance(mutable, InputChoice): | |||||
key = mutable.key | |||||
if key not in keys: | |||||
node_type = "normal_cell" if "normal" in key else "reduction_cell" | |||||
res["search_space"][node_type][key] = {"_type": "input_choice", | |||||
"_value": {"candidates": mutable.choose_from, | |||||
"n_chosen": mutable.n_chosen}} | |||||
keys.append(key) | |||||
else: | |||||
raise TypeError("Unsupported mutable type: '%s'." % type(mutable)) | |||||
return res | |||||
@@ -0,0 +1,81 @@ | |||||
# train stage | |||||
`python pcdarts_train.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --search_space_path 'experiment_id/search_space.json' --best_selected_space_path 'experiment_id/best_selected_space.json' --trial_id 0 --layers 5 --model_lr 0.025 --arch_lr 3e-4 --epochs 2 --pre_epochs 1 --batch_size 64 --channels 16` | |||||
# select stage | |||||
`python pcdarts_select.py --best_selected_space_path 'experiment_id/best_selected_space.json' ` | |||||
# retrain stage | |||||
`python pcdarts_retrain.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --best_selected_space_path 'experiment_id/best_selected_space.json' --best_checkpoint_dir 'experiment_id/' --trial_id 0 --batch_size 96 --epochs 2 --lr 0.01 --layers 20 --channels 36` | |||||
# output file | |||||
`result.json` | |||||
``` | |||||
{'type': 'Accuracy', 'result': {'sequence': 0, 'category': 'epoch', 'value': 0.1}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 1, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 2, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 3, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 4, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Cost_time', 'result': {'value': '41.614346981048584 s'}} | |||||
``` | |||||
`search_space.json` | |||||
``` | |||||
{ | |||||
"op_list": { | |||||
"_type": "layer_choice", | |||||
"_value": [ | |||||
"maxpool", | |||||
"avgpool", | |||||
"skipconnect", | |||||
"sepconv3x3", | |||||
"sepconv5x5", | |||||
"dilconv3x3", | |||||
"dilconv5x5", | |||||
"none" | |||||
] | |||||
}, | |||||
"search_space": { | |||||
"normal_n2_p0": "op_list", | |||||
"normal_n2_p1": "op_list", | |||||
"normal_n2_switch": { | |||||
"_type": "input_choice", | |||||
"_value": { | |||||
"candidates": [ | |||||
"normal_n2_p0", | |||||
"normal_n2_p1" | |||||
], | |||||
"n_chosen": 2 | |||||
} | |||||
}, | |||||
... | |||||
} | |||||
``` | |||||
`best_selected_space.json` | |||||
``` | |||||
{ | |||||
"normal_n2_p0": "dilconv5x5", | |||||
"normal_n2_p1": "dilconv5x5", | |||||
"normal_n2_switch": [ | |||||
"normal_n2_p0", | |||||
"normal_n2_p1" | |||||
], | |||||
"normal_n3_p0": "sepconv3x3", | |||||
"normal_n3_p1": "dilconv5x5", | |||||
"normal_n3_p2": [], | |||||
"normal_n3_switch": [ | |||||
"normal_n3_p0", | |||||
"normal_n3_p1" | |||||
], | |||||
"normal_n4_p0": [], | |||||
"normal_n4_p1": "dilconv5x5", | |||||
"normal_n4_p2": "sepconv5x5", | |||||
"normal_n4_p3": [], | |||||
"normal_n4_switch": [ | |||||
"normal_n4_p1", | |||||
"normal_n4_p2" | |||||
], | |||||
... | |||||
} | |||||
``` |
@@ -0,0 +1,2 @@ | |||||
from pytorch.pdarts.pdartsmutator import PdartsMutator | |||||
from pytorch.pdarts.pdartstrainer import PdartsTrainer |
@@ -0,0 +1,177 @@ | |||||
# Copyright (c) Microsoft Corporation. | |||||
# Licensed under the MIT license. | |||||
from collections import OrderedDict | |||||
import torch | |||||
import torch.nn as nn | |||||
from pytorch import mutables | |||||
from pytorch.darts import ops | |||||
class AuxiliaryHead(nn.Module): | |||||
""" Auxiliary head in 2/3 place of network to let the gradient flow well """ | |||||
def __init__(self, input_size, C, n_classes): | |||||
""" assuming input size 7x7 or 8x8 """ | |||||
assert input_size in [7, 8] | |||||
super().__init__() | |||||
self.net = nn.Sequential( | |||||
nn.ReLU(inplace=True), | |||||
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out | |||||
nn.Conv2d(C, 128, kernel_size=1, bias=False), | |||||
nn.BatchNorm2d(128), | |||||
nn.ReLU(inplace=True), | |||||
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out | |||||
nn.BatchNorm2d(768), | |||||
nn.ReLU(inplace=True) | |||||
) | |||||
self.linear = nn.Linear(768, n_classes) | |||||
def forward(self, x): | |||||
out = self.net(x) | |||||
out = out.view(out.size(0), -1) # flatten | |||||
logits = self.linear(out) | |||||
return logits | |||||
class Node(nn.Module): | |||||
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, search, dropout_rate): | |||||
super().__init__() | |||||
self.dropout_rate = dropout_rate | |||||
self.ops = nn.ModuleList() | |||||
choice_keys = [] | |||||
for i in range(num_prev_nodes): | |||||
stride = 2 if i < num_downsample_connect else 1 | |||||
choice_keys.append("{}_p{}".format(node_id, i)) | |||||
skip_op = nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False) | |||||
# In search, op-level dropout for skip-connect | |||||
if search and self.dropout_rate > 0: | |||||
skip_op = nn.Sequential(skip_op, nn.Dropout(self.dropout_rate)) | |||||
self.ops.append( | |||||
mutables.LayerChoice(OrderedDict([ | |||||
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)), | |||||
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)), | |||||
("skipconnect", skip_op), | |||||
("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)), | |||||
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)), | |||||
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)), | |||||
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)) | |||||
]), key=choice_keys[-1])) | |||||
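            # note: the "none" (zero) operation is not instantiated as a module here;
            # the mutator keeps one extra architecture weight for it (see PdartsMutator)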
        # In retrain, DropPath is applied to non-skip-connect ops; p in DropPath defaults to 0
self.drop_path = ops.DropPath() | |||||
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) | |||||
def forward(self, prev_nodes): | |||||
assert len(self.ops) == len(prev_nodes) | |||||
output = [] | |||||
for op, node in zip(self.ops, prev_nodes): | |||||
out = op(node) | |||||
# In retrain | |||||
if out is not None: | |||||
if not isinstance(op, nn.Identity): | |||||
out = self.drop_path(out) | |||||
else: | |||||
out = None | |||||
output.append(out) | |||||
# out = [op(node) for op, node in zip(self.ops, prev_nodes)] | |||||
# out = [self.drop_path(o) if o is not None else None for o in out] | |||||
return self.input_switch(output) | |||||
class Cell(nn.Module): | |||||
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, search, dropout_rate): | |||||
super().__init__() | |||||
self.reduction = reduction | |||||
self.n_nodes = n_nodes | |||||
        # If the previous cell is a reduction cell, the current input size does not match
        # the output size of cell[k-2], so output[k-2] is downsampled by the preprocessing op.
if reduction_p: | |||||
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) | |||||
else: | |||||
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) | |||||
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) | |||||
# generate dag | |||||
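        # Node ids follow the pattern "<normal|reduce>_n<depth>"; each Node then creates
        # LayerChoice keys "<node_id>_p<i>" and an InputChoice key "<node_id>_switch",
        # which is exactly the key scheme that appears in search_space.json.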
self.mutable_ops = nn.ModuleList() | |||||
for depth in range(2, self.n_nodes + 2): | |||||
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth), depth, channels, 2 if reduction else 0, search, dropout_rate)) | |||||
def forward(self, s0, s1): | |||||
        # s0, s1 are the outputs of the previous-previous cell and the previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)] | |||||
for node in self.mutable_ops: | |||||
cur_tensor = node(tensors) | |||||
tensors.append(cur_tensor) | |||||
output = torch.cat(tensors[2:], dim=1) | |||||
return output | |||||
class CNN(nn.Module): | |||||
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, dropout_rate, n_nodes=4, stem_multiplier=3, auxiliary=False, search=True): | |||||
super().__init__() | |||||
self.in_channels = in_channels | |||||
self.channels = channels | |||||
self.n_classes = n_classes | |||||
self.n_layers = n_layers | |||||
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1 | |||||
c_cur = stem_multiplier * self.channels | |||||
self.stem = nn.Sequential( | |||||
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), | |||||
nn.BatchNorm2d(c_cur) | |||||
) | |||||
# for the first cell, stem is used for both s0 and s1 | |||||
        # [!] channels_pp and channels_p are output channel sizes, while c_cur is the input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels | |||||
self.cells = nn.ModuleList() | |||||
reduction_p, reduction = False, False | |||||
for i in range(n_layers): | |||||
reduction_p, reduction = reduction, False | |||||
            # Reduce feature map size and double channels at the 1/3 and 2/3 layers.
if i in [n_layers // 3, 2 * n_layers // 3]: | |||||
c_cur *= 2 | |||||
reduction = True | |||||
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, search, dropout_rate) | |||||
self.cells.append(cell) | |||||
c_cur_out = c_cur * n_nodes | |||||
channels_pp, channels_p = channels_p, c_cur_out | |||||
if i == self.aux_pos: | |||||
self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes) | |||||
self.gap = nn.AdaptiveAvgPool2d(1) | |||||
self.linear = nn.Linear(channels_p, n_classes) | |||||
def forward(self, x): | |||||
s0 = s1 = self.stem(x) | |||||
aux_logits = None | |||||
for i, cell in enumerate(self.cells): | |||||
s0, s1 = s1, cell(s0, s1) | |||||
if i == self.aux_pos and self.training: | |||||
aux_logits = self.aux_head(s1) | |||||
out = self.gap(s1) | |||||
out = out.view(out.size(0), -1) # flatten | |||||
logits = self.linear(out) | |||||
if aux_logits is not None: | |||||
return logits, aux_logits | |||||
return logits | |||||
def drop_path_prob(self, p, search=True): | |||||
if search: | |||||
for module in self.modules(): | |||||
# In search, update dropout rate | |||||
if isinstance(module, nn.Sequential) and isinstance(module[0], nn.Identity): | |||||
                    module[1].p = p  # nn.Dropout stores its rate in .p
else: | |||||
# In retrain, update ops.DropPath | |||||
for module in self.modules(): | |||||
if isinstance(module, ops.DropPath): | |||||
module.p = p |
@@ -0,0 +1,203 @@ | |||||
import sys | |||||
sys.path.append('../..')
import os | |||||
import logging | |||||
import time | |||||
import json | |||||
from argparse import ArgumentParser | |||||
import torch | |||||
import torch.nn as nn | |||||
# from torch.utils.tensorboard import SummaryWriter | |||||
from model import CNN | |||||
from pytorch.fixed import apply_fixed_architecture | |||||
from pytorch.utils import set_seed, mkdirs, init_logger, save_best_checkpoint, AverageMeter | |||||
from pytorch.darts import utils | |||||
from pytorch.darts import datasets | |||||
from pytorch.retrainer import Retrainer | |||||
logger = logging.getLogger(__name__) | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
# writer = SummaryWriter() | |||||
class PdartsRetrainer(Retrainer): | |||||
def __init__(self, aux_weight, grad_clip, epochs, log_frequency): | |||||
self.aux_weight = aux_weight | |||||
self.grad_clip = grad_clip | |||||
self.epochs = epochs | |||||
self.log_frequency = log_frequency | |||||
def train(self, train_loader, model, optimizer, criterion, epoch): | |||||
top1 = AverageMeter("top1") | |||||
top5 = AverageMeter("top5") | |||||
losses = AverageMeter("losses") | |||||
cur_step = epoch * len(train_loader) | |||||
cur_lr = optimizer.param_groups[0]["lr"] | |||||
logger.info("Epoch %d LR %.6f", epoch, cur_lr) | |||||
# writer.add_scalar("lr", cur_lr, global_step=cur_step) | |||||
model.train() | |||||
for step, (x, y) in enumerate(train_loader): | |||||
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) | |||||
bs = x.size(0) | |||||
optimizer.zero_grad() | |||||
logits, aux_logits = model(x) | |||||
loss = criterion(logits, y) | |||||
if self.aux_weight > 0.: | |||||
loss += self.aux_weight * criterion(aux_logits, y) | |||||
loss.backward() | |||||
# gradient clipping | |||||
nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip) | |||||
optimizer.step() | |||||
accuracy = utils.accuracy(logits, y, topk=(1, 5)) | |||||
losses.update(loss.item(), bs) | |||||
top1.update(accuracy["acc1"], bs) | |||||
top5.update(accuracy["acc5"], bs) | |||||
# writer.add_scalar("loss/train", loss.item(), global_step=cur_step) | |||||
# writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step) | |||||
# writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step) | |||||
if step % self.log_frequency == 0 or step == len(train_loader) - 1: | |||||
logger.info( | |||||
"Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " | |||||
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( | |||||
epoch + 1, self.epochs, step, len(train_loader) - 1, losses=losses, | |||||
top1=top1, top5=top5)) | |||||
cur_step += 1 | |||||
logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg)) | |||||
def validate(self, valid_loader, model, criterion, epoch, cur_step): | |||||
top1 = AverageMeter("top1") | |||||
top5 = AverageMeter("top5") | |||||
losses = AverageMeter("losses") | |||||
model.eval() | |||||
with torch.no_grad(): | |||||
for step, (X, y) in enumerate(valid_loader): | |||||
X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) | |||||
bs = X.size(0) | |||||
logits = model(X) | |||||
loss = criterion(logits, y) | |||||
accuracy = utils.accuracy(logits, y, topk=(1, 5)) | |||||
losses.update(loss.item(), bs) | |||||
top1.update(accuracy["acc1"], bs) | |||||
top5.update(accuracy["acc5"], bs) | |||||
if step % self.log_frequency == 0 or step == len(valid_loader) - 1: | |||||
logger.info( | |||||
"Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} " | |||||
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( | |||||
epoch + 1, self.epochs, step, len(valid_loader) - 1, losses=losses, | |||||
top1=top1, top5=top5)) | |||||
# writer.add_scalar("loss/test", losses.avg, global_step=cur_step) | |||||
# writer.add_scalar("acc1/test", top1.avg, global_step=cur_step) | |||||
# writer.add_scalar("acc5/test", top5.avg, global_step=cur_step) | |||||
logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg)) | |||||
return top1.avg | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("Pdarts retrain") | |||||
parser.add_argument("--data_dir", type=str, | |||||
default='./', help="search_space json file") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='./result.json', help="training result") | |||||
parser.add_argument("--log_path", type=str, | |||||
default='.0/log', help="log for info") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
parser.add_argument("--best_checkpoint_dir", type=str, | |||||
default='', help="default name is best_checkpoint_epoch{}.pth") | |||||
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
parser.add_argument("--layers", default=20, type=int) | |||||
parser.add_argument("--batch_size", default=96, type=int) | |||||
parser.add_argument("--log_frequency", default=10, type=int) | |||||
parser.add_argument("--epochs", default=600, type=int) | |||||
parser.add_argument("--lr", default=0.025, type=float) | |||||
parser.add_argument("--channels", default=36, type=int) | |||||
parser.add_argument("--aux_weight", default=0.4, type=float) | |||||
parser.add_argument("--drop_path_prob", default=0.3, type=float) | |||||
parser.add_argument("--workers", default=4) | |||||
parser.add_argument("--grad_clip", default=5., type=float) | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir) | |||||
init_logger(args.log_path) | |||||
logger.info(args) | |||||
set_seed(args.trial_id) | |||||
logger.info("loading data") | |||||
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir) | |||||
    model = CNN(32, 3, args.channels, 10, args.layers, auxiliary=True, search=False, dropout_rate=0.0)
if isinstance(args.best_selected_space_path, str): | |||||
with open(args.best_selected_space_path) as f: | |||||
fixed_arc = json.load(f) | |||||
apply_fixed_architecture(model, fixed_arc=fixed_arc["best_selected_space"]) | |||||
criterion = nn.CrossEntropyLoss() | |||||
model.to(device) | |||||
criterion.to(device) | |||||
optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6) | |||||
train_loader = torch.utils.data.DataLoader(dataset_train, | |||||
batch_size=args.batch_size, | |||||
shuffle=True, | |||||
num_workers=args.workers, | |||||
pin_memory=True) | |||||
valid_loader = torch.utils.data.DataLoader(dataset_valid, | |||||
batch_size=args.batch_size, | |||||
shuffle=False, | |||||
num_workers=args.workers, | |||||
pin_memory=True) | |||||
retrainer = PdartsRetrainer(aux_weight=args.aux_weight, | |||||
grad_clip=args.grad_clip, | |||||
epochs=args.epochs, | |||||
                                log_frequency=args.log_frequency)
best_top1 = 0. | |||||
start_time = time.time() | |||||
for epoch in range(args.epochs): | |||||
drop_prob = args.drop_path_prob * epoch / args.epochs | |||||
        model.drop_path_prob(drop_prob, search=False)  # retrain mode: update ops.DropPath, not search dropout
# training | |||||
retrainer.train(train_loader, model, optimizer, criterion, epoch) | |||||
# validation | |||||
cur_step = (epoch + 1) * len(train_loader) | |||||
top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step) | |||||
        # the backend filters this pattern from the terminal output: {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) + '\n') | |||||
best_top1 = max(best_top1, top1) | |||||
lr_scheduler.step() | |||||
logger.info("Final best Prec@1 = {:.4%}".format(best_top1)) | |||||
cost_time = time.time() - start_time | |||||
    # the backend filters this pattern from the terminal output: {"type": "Cost_time", "result": {"value": "* s"}}
logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) | |||||
# result["Cost_time"] = str(cost_time) + ' s' | |||||
# dump_global_result(args.result_path, result) | |||||
save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch) | |||||
logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch)))) |
@@ -0,0 +1,22 @@ | |||||
import sys | |||||
sys.path.append('../..') | |||||
from pytorch.selector import Selector | |||||
from argparse import ArgumentParser | |||||
class PdartsSelector(Selector): | |||||
def __init__(self, single_candidate=True): | |||||
super().__init__(single_candidate) | |||||
def fit(self): | |||||
pass | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("PDARTS select") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
args = parser.parse_args() | |||||
darts_selector = PdartsSelector(True) | |||||
darts_selector.fit() | |||||
@@ -0,0 +1,79 @@ | |||||
import sys | |||||
sys.path.append('../..')
import time | |||||
import logging | |||||
from argparse import ArgumentParser | |||||
from pdartstrainer import PdartsTrainer | |||||
from pytorch.utils import mkdirs, set_seed, init_logger, list_str2int | |||||
logger = logging.getLogger(__name__) | |||||
if __name__ == "__main__": | |||||
parser = ArgumentParser("pdarts") | |||||
parser.add_argument("--data_dir", type=str, | |||||
default='../data/', help="search_space json file") | |||||
parser.add_argument("--result_path", type=str, | |||||
default='0/result.json', help="training result") | |||||
parser.add_argument("--log_path", type=str, | |||||
default='0/log', help="log for info") | |||||
parser.add_argument("--search_space_path", type=str, | |||||
default='./search_space.json', help="search space of PDARTS") | |||||
parser.add_argument("--best_selected_space_path", type=str, | |||||
default='./best_selected_space.json', help="final best selected space") | |||||
    parser.add_argument('--trial_id', type=int, default=0, help='for ensuring reproducibility')
parser.add_argument('--model_lr', type=float, default=0.025, help='learning rate for training model weights') | |||||
parser.add_argument('--arch_lr', type=float, default=3e-4, help='learning rate for training architecture') | |||||
parser.add_argument("--epochs", default=2, type=int) | |||||
parser.add_argument("--pre_epochs", default=15, type=int) | |||||
parser.add_argument("--batch_size", default=96, type=int) | |||||
parser.add_argument("--init_layers", default=5, type=int) | |||||
parser.add_argument('--add_layers', default=[0, 6, 12], nargs='+', type=int, help='add layers in each stage') | |||||
parser.add_argument('--dropped_ops', default=[3, 2, 1], nargs='+', type=int, help='drop ops in each stage') | |||||
parser.add_argument('--dropout_rates', default=[0.1, 0.4, 0.7], nargs='+', type=float, help='drop ops probability in each stage') | |||||
# parser.add_argument('--add_layers', action='append', help='add layers in each stage') | |||||
# parser.add_argument('--dropped_ops', action='append', help='drop ops in each stage') | |||||
# parser.add_argument('--dropout_rates', action='append', help='drop ops probability in each stage') | |||||
parser.add_argument("--channels", default=16, type=int) | |||||
parser.add_argument("--log_frequency", default=50, type=int) | |||||
parser.add_argument("--class_num", default=10, type=int) | |||||
parser.add_argument("--unrolled", default=False, action="store_true") | |||||
args = parser.parse_args() | |||||
mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path) | |||||
init_logger(args.log_path, "info") | |||||
set_seed(args.trial_id) | |||||
# args.add_layers = list_str2int(args.add_layers) | |||||
# args.dropped_ops = list_str2int(args.dropped_ops) | |||||
# args.dropout_rates = list_str2int(args.dropout_rates) | |||||
logger.info(args) | |||||
logger.info("initializing pdarts trainer") | |||||
trainer = PdartsTrainer( | |||||
init_layers=args.init_layers, | |||||
pdarts_num_layers=args.add_layers, | |||||
pdarts_num_to_drop=args.dropped_ops, | |||||
pdarts_dropout_rates=args.dropout_rates, | |||||
num_epochs=args.epochs, | |||||
num_pre_epochs=args.pre_epochs, | |||||
model_lr=args.model_lr, | |||||
arch_lr=args.arch_lr, | |||||
batch_size=args.batch_size, | |||||
class_num=args.class_num, | |||||
channels=args.channels, | |||||
result_path=args.result_path, | |||||
log_frequency=args.log_frequency, | |||||
unrolled=args.unrolled, | |||||
        data_dir=args.data_dir,
search_space_path=args.search_space_path, | |||||
best_selected_space_path=args.best_selected_space_path | |||||
) | |||||
logger.info("training") | |||||
start_time = time.time() | |||||
trainer.train(validate=True) | |||||
# result = trainer.result | |||||
cost_time = time.time() - start_time | |||||
    # the backend filters this pattern from the terminal output: {"type": "Cost_time", "result": {"value": "* s"}}
logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}) | |||||
with open(args.result_path, "a") as file: | |||||
file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})) |
@@ -0,0 +1,201 @@ | |||||
import copy | |||||
import numpy as np | |||||
import torch | |||||
import logging | |||||
from collections import OrderedDict | |||||
from torch import nn | |||||
from pytorch.darts.dartsmutator import DartsMutator | |||||
from pytorch.mutables import LayerChoice, InputChoice | |||||
logger = logging.getLogger(__name__) | |||||
class PdartsMutator(DartsMutator): | |||||
""" | |||||
It works with PdartsTrainer to calculate ops weights, | |||||
and drop weights in different PDARTS epochs. | |||||
""" | |||||
    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
self.pdarts_epoch_index = pdarts_epoch_index | |||||
self.pdarts_num_to_drop = pdarts_num_to_drop | |||||
        # save the last two switches and choices, used later for restricting skip-connects
self.last_two_switches = None | |||||
self.last_two_choices = None | |||||
if switches is None: | |||||
self.switches = {} | |||||
else: | |||||
self.switches = switches | |||||
super(PdartsMutator, self).__init__(model) | |||||
        # this loop goes through mutables with different keys;
        # it mainly updates the lengths of the choices.
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
switches = self.switches.get(mutable.key, [True for j in range(len(mutable))]) | |||||
# choices = self.choices[mutable.key] | |||||
operations_count = np.sum(switches) | |||||
                # the +1 accounts for the zero operation in the DARTS network:
                # the zero operation is not in the network's choices list (switches), but its weight is,
                # so one extra weight entry is needed for it.
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1)) | |||||
self.switches[mutable.key] = switches | |||||
        # update LayerChoice instances in the model,
        # physically removing the dropped candidate operations.
for module in self.model.modules(): | |||||
if isinstance(module, LayerChoice): | |||||
switches = self.switches.get(module.key) | |||||
choices = self.choices[module.key] | |||||
if len(module) > len(choices): | |||||
                    # iterate from last to first, so removing an entry does not affect earlier indices.
for index in range(len(switches)-1, -1, -1): | |||||
                        if not switches[index]:
del module[index] | |||||
assert len(module) <= len(choices), "Failed to remove dropped choices." | |||||
def export(self, last, switches): | |||||
        # In the last P-DARTS epoch, the number of skip-connects needs to be restricted.
        # We cannot rely on super().export() because P-DARTS physically deletes some choices, leaving misaligned lengths.
if last: | |||||
            # restrict to at most 2 skip-connects (normal cells only)
name = "normal" | |||||
max_num = 2 | |||||
skip_num = self.check_skip_num(name, switches) | |||||
logger.info("Initially, the number of skipconnect is {}.".format(skip_num)) | |||||
while skip_num > max_num: | |||||
logger.info("Restricting {} skipconnect to {}.".format(skip_num, max_num)) | |||||
logger.info("Original normal_switch is {}.".format(switches)) | |||||
                # set the skip-connect weight in self.choices to 0 and its flag in self.switches to False
switches = self.delete_min_sk(name, switches) | |||||
logger.info("Restricted normal_switch is {}.".format(switches)) | |||||
skip_num = self.check_skip_num(name, switches) | |||||
        # get boolean masks from sample_final(), then fill the dropped positions back in below
results = super().sample_final() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
                # since some operations were physically dropped,
                # False entries are filled back in to account for them.
trained_result = results[mutable.key] | |||||
trained_index = 0 | |||||
switches = self.switches[mutable.key] | |||||
result = torch.Tensor(switches).bool() | |||||
for index in range(len(result)): | |||||
if result[index]: | |||||
result[index] = trained_result[trained_index] | |||||
trained_index += 1 | |||||
results[mutable.key] = result | |||||
return results | |||||
def drop_paths(self): | |||||
""" | |||||
        This method is called when a P-DARTS epoch is finished.
        It prepares the switches for the next epoch.
        Candidate operations whose switch is False will be dropped in the next epoch.
""" | |||||
all_switches = copy.deepcopy(self.switches) | |||||
for key in all_switches: | |||||
switches = all_switches[key] | |||||
idxs = [] | |||||
for j in range(len(switches)): | |||||
if switches[j]: | |||||
idxs.append(j) | |||||
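            # the last architecture weight belongs to the implicit zero operation,
            # so exclude it when ranking the remaining candidates to drop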
            weights = self.choices[key].data.cpu().numpy()[:-1]
            drop = np.argsort(weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
for idx in drop: | |||||
switches[idxs[idx]] = False | |||||
return all_switches | |||||
def check_skip_num(self, name, switches): | |||||
counter = 0 | |||||
for key in switches: | |||||
if name in key: | |||||
                # the zero operation is not in switches, so skipconnect sits at index 2
if switches[key][2]: | |||||
counter += 1 | |||||
return counter | |||||
def delete_min_sk(self, name, switches): | |||||
def _get_sk_idx(key, switches): | |||||
if not switches[key][2]: | |||||
idx = -1 | |||||
else: | |||||
idx = 0 | |||||
for i in range(2): | |||||
                    # count how many of the first two ops are still alive in self.switches
                    # to locate skipconnect's index in the compacted choices vector
if self.switches[key][i]: | |||||
idx += 1 | |||||
return idx | |||||
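        # a cell with 4 intermediate nodes has 2 + 3 + 4 + 5 = 14 input edges,
        # hence 14 possible skip-connect positions per cell type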
sk_choices = [1.0 for i in range(14)] | |||||
        sk_keys = [None for i in range(14)]  # keys whose skip-connect is still alive
sk_choices_idx = -1 | |||||
for key in switches: | |||||
if name in key: | |||||
                # keys are iterated in their default (definition) order
sk_choices_idx += 1 | |||||
idx = _get_sk_idx(key, switches) | |||||
                if idx != -1:
sk_keys[sk_choices_idx] = key | |||||
sk_choices[sk_choices_idx] = self.choices[key][idx] | |||||
min_sk_idx = np.argmin(sk_choices) | |||||
idx = _get_sk_idx(sk_keys[min_sk_idx], switches) | |||||
# modify self.choices or copy.deepcopy ? | |||||
self.choices[sk_keys[min_sk_idx]][idx] = 0.0 | |||||
        # modify self.switches or copy.deepcopy ?
        # self.switches holds the previous switches, while `switches` holds the present ones
self.switches[sk_keys[min_sk_idx]][2] = False | |||||
switches[sk_keys[min_sk_idx]][2] = False | |||||
return switches | |||||
def _generate_search_space(self): | |||||
""" | |||||
Generate search space from mutables. | |||||
Here is the search space format: | |||||
:: | |||||
{ key_name: {"_type": "layer_choice", | |||||
"_value": ["conv1", "conv2"]} } | |||||
{ key_name: {"_type": "input_choice", | |||||
"_value": {"candidates": ["in1", "in2"], | |||||
"n_chosen": 1}} } | |||||
Returns | |||||
------- | |||||
dict | |||||
the generated search space | |||||
""" | |||||
res = OrderedDict() | |||||
res["op_list"] = OrderedDict() | |||||
res["search_space"] = {"reduction_cell": OrderedDict(), "normal_cell": OrderedDict()} | |||||
keys = [] | |||||
for mutable in self.mutables: | |||||
# for now we only generate flattened search space | |||||
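            # 36 keys = 2 cell types x (14 edge LayerChoices + 4 InputChoice switches)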
if (len(res["search_space"]["reduction_cell"]) + len(res["search_space"]["normal_cell"])) >= 36: | |||||
break | |||||
if isinstance(mutable, LayerChoice): | |||||
key = mutable.key | |||||
if key not in keys: | |||||
val = mutable.names | |||||
if not res["op_list"]: | |||||
res["op_list"] = {"_type": "layer_choice", "_value": val + ["none"]} | |||||
node_type = "normal_cell" if "normal" in key else "reduction_cell" | |||||
res["search_space"][node_type][key] = "op_list" | |||||
keys.append(key) | |||||
elif isinstance(mutable, InputChoice): | |||||
key = mutable.key | |||||
if key not in keys: | |||||
node_type = "normal_cell" if "normal" in key else "reduction_cell" | |||||
res["search_space"][node_type][key] = {"_type": "input_choice", | |||||
"_value": {"candidates": mutable.choose_from, | |||||
"n_chosen": mutable.n_chosen}} | |||||
keys.append(key) | |||||
else: | |||||
raise TypeError("Unsupported mutable type: '%s'." % type(mutable)) | |||||
return res |
@@ -0,0 +1,167 @@ | |||||
import os | |||||
import logging | |||||
import torch | |||||
import torch.nn as nn | |||||
import numpy as np | |||||
from collections import OrderedDict | |||||
import json | |||||
from pytorch.callbacks import LRSchedulerCallback | |||||
from pytorch.trainer import BaseTrainer, TorchTensorEncoder | |||||
from pytorch.utils import dump_global_result | |||||
from model import CNN | |||||
from pdartsmutator import PdartsMutator | |||||
from pytorch.darts.utils import accuracy | |||||
from pytorch.darts import datasets | |||||
from pytorch.darts.dartstrainer import DartsTrainer | |||||
logger = logging.getLogger(__name__) | |||||
class PdartsTrainer(BaseTrainer): | |||||
""" | |||||
    This trainer implements the P-DARTS algorithm.
    P-DARTS builds on the DARTS algorithm and adds a progressive network-growth approach to find deeper and better networks.
    This class relies on the pdarts_num_layers and pdarts_num_to_drop parameters to control how the network grows.
    pdarts_num_layers means how many layers are added compared with the first stage.
    pdarts_num_to_drop means how many candidate operations are dropped in each stage,
    so that the grown network stays at a similar size.
""" | |||||
def __init__(self, init_layers, pdarts_num_layers, pdarts_num_to_drop, pdarts_dropout_rates, num_epochs, num_pre_epochs, model_lr, class_num, | |||||
arch_lr, channels, batch_size, result_path, log_frequency, unrolled, data_dir, search_space_path, | |||||
best_selected_space_path, device=None, workers=4): | |||||
super(PdartsTrainer, self).__init__() | |||||
self.init_layers = init_layers | |||||
self.class_num = class_num | |||||
self.channels = channels | |||||
self.model_lr = model_lr | |||||
self.num_epochs = num_epochs | |||||
self.pdarts_num_layers = pdarts_num_layers | |||||
self.pdarts_num_to_drop = pdarts_num_to_drop | |||||
self.pdarts_dropout_rates = pdarts_dropout_rates | |||||
self.pdarts_epoches = len(pdarts_num_to_drop) | |||||
self.search_space_path = search_space_path | |||||
self.best_selected_space_path = best_selected_space_path | |||||
logger.info("loading data") | |||||
dataset_train, dataset_valid = datasets.get_dataset( | |||||
"cifar10", root=data_dir) | |||||
self.darts_parameters = { | |||||
"metrics": lambda output, target: accuracy(output, target, topk=(1,)), | |||||
"arch_lr": arch_lr, | |||||
"num_epochs": num_epochs, | |||||
"num_pre_epochs": num_pre_epochs, | |||||
"dataset_train": dataset_train, | |||||
"dataset_valid": dataset_valid, | |||||
"batch_size": batch_size, | |||||
"result_path": result_path, | |||||
"workers": workers, | |||||
"device": device, | |||||
"log_frequency": log_frequency, | |||||
"unrolled": unrolled, | |||||
"search_space_path": None | |||||
} | |||||
def train(self, validate=False): | |||||
switches = None | |||||
last = False | |||||
for epoch in range(self.pdarts_epoches): | |||||
if epoch == self.pdarts_epoches - 1: | |||||
last = True | |||||
# create network for each stage | |||||
layers = self.init_layers + self.pdarts_num_layers[epoch] | |||||
init_dropout_rate = float(self.pdarts_dropout_rates[epoch]) | |||||
model = CNN(32, 3, self.channels, self.class_num, layers, | |||||
init_dropout_rate, n_nodes=4, search=True) | |||||
criterion = nn.CrossEntropyLoss() | |||||
optim = torch.optim.SGD( | |||||
model.parameters(), self.model_lr, momentum=0.9, weight_decay=3.0E-4) | |||||
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( | |||||
optim, self.num_epochs, eta_min=0.001) | |||||
logger.info( | |||||
"############Start PDARTS training epoch %s############", epoch) | |||||
self.mutator = PdartsMutator( | |||||
model, epoch, self.pdarts_num_to_drop, switches) | |||||
if epoch == 0: | |||||
# only write original search space in first stage | |||||
search_space = self.mutator._generate_search_space() | |||||
dump_global_result(self.search_space_path, | |||||
search_space) | |||||
darts_callbacks = [] | |||||
if lr_scheduler is not None: | |||||
darts_callbacks.append(LRSchedulerCallback(lr_scheduler)) | |||||
# darts_callbacks.append(ArchitectureCheckpoint( | |||||
# os.path.join(self.selected_space_path, "stage_{}".format(epoch)))) | |||||
self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, | |||||
optimizer=optim, callbacks=darts_callbacks, **self.darts_parameters) | |||||
for train_epoch in range(self.darts_parameters["num_epochs"]): | |||||
for callback in darts_callbacks: | |||||
callback.on_epoch_begin(train_epoch) | |||||
# training | |||||
logger.info("Epoch %d Training", train_epoch) | |||||
if train_epoch < self.darts_parameters["num_pre_epochs"]: | |||||
dropout_rate = init_dropout_rate * \ | |||||
(self.darts_parameters["num_epochs"] - train_epoch - | |||||
1) / self.darts_parameters["num_epochs"] | |||||
else: | |||||
                    # scale_factor = 0.2
                    dropout_rate = init_dropout_rate * \
                        np.exp(-(train_epoch - self.darts_parameters["num_pre_epochs"]) * 0.2)
model.drop_path_prob(search=True, p=dropout_rate) | |||||
self.trainer.train_one_epoch(train_epoch) | |||||
if validate: | |||||
# validation | |||||
logger.info("Epoch %d Validating", train_epoch + 1) | |||||
self.trainer.validate_one_epoch( | |||||
                        train_epoch, log_print=last)
for callback in darts_callbacks: | |||||
callback.on_epoch_end(train_epoch) | |||||
switches = self.mutator.drop_paths() | |||||
            # In the last P-DARTS epoch, restrict the number of skip-connects and save the best structure
if last: | |||||
res = OrderedDict() | |||||
op_value = [value for value in search_space["op_list"]["_value"] if value != 'none'] | |||||
res["op_list"] = search_space["op_list"] | |||||
res["op_list"]["_value"] = op_value | |||||
res["best_selected_space"] = self.mutator.export(last, switches) | |||||
logger.info(res) | |||||
dump_global_result(self.best_selected_space_path, res) | |||||
def validate(self): | |||||
self.trainer.validate() | |||||
def export(self, file, last, switches): | |||||
        mutator_export = self.mutator.export(last, switches)
with open(file, "w") as f: | |||||
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder) | |||||
def checkpoint(self, file_path, epoch): | |||||
if isinstance(self.model, nn.DataParallel): | |||||
child_model_state_dict = self.model.module.state_dict() | |||||
else: | |||||
child_model_state_dict = self.model.state_dict() | |||||
save_state = {'child_model_state_dict': child_model_state_dict, | |||||
'optimizer_state_dict': self.optimizer.state_dict(), | |||||
'epoch': epoch} | |||||
dest_path = os.path.join( | |||||
file_path, "best_checkpoint_epoch_{}.pth.tar".format(epoch)) | |||||
logger.info("Saving model to %s", dest_path) | |||||
torch.save(save_state, dest_path) | |||||
raise NotImplementedError("Not implemented yet") | |||||
@@ -0,0 +1,92 @@ | |||||
# train stage | |||||
`python pdarts_train.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --search_space_path 'experiment_id/search_space.json' --best_selected_space_path 'experiment_id/best_selected_space.json' --trial_id 0 --model_lr 0.025 --arch_lr 3e-4 --epochs 2 --pre_epochs 1 --batch_size 64 --channels 16 --init_layers 5 --add_layers 0 6 12 --dropped_ops 3 2 1 --dropout_rates 0.1 0.4 0.7`
# select stage | |||||
`python pdarts_select.py --best_selected_space_path 'experiment_id/best_selected_space.json'` | |||||
# retrain stage | |||||
`python pdarts_retrain.py --data_dir '../data/' --result_path 'trial_id/result.json' --log_path 'trial_id/log' --best_selected_space_path 'experiment_id/best_selected_space.json' --best_checkpoint_dir 'experiment_id/' --trial_id 0 --batch_size 96 --epochs 2 --lr 0.025 --layers 20 --channels 36` | |||||
# output file | |||||
`result.json` | |||||
``` | |||||
{'type': 'Accuracy', 'result': {'sequence': 0, 'category': 'epoch', 'value': 0.1}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 1, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 2, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 3, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Accuracy', 'result': {'sequence': 4, 'category': 'epoch', 'value': 0.0}} | |||||
{'type': 'Cost_time', 'result': {'value': '41.614346981048584 s'}} | |||||
``` | |||||
`search_space.json` | |||||
``` | |||||
{ | |||||
"op_list": { | |||||
"_type": "layer_choice", | |||||
"_value": [ | |||||
"maxpool", | |||||
"avgpool", | |||||
"skipconnect", | |||||
"sepconv3x3", | |||||
"sepconv5x5", | |||||
"dilconv3x3", | |||||
"dilconv5x5", | |||||
"none" | |||||
] | |||||
}, | |||||
"search_space": { | |||||
"normal_n2_p0": "op_list", | |||||
"normal_n2_p1": "op_list", | |||||
"normal_n2_switch": { | |||||
"_type": "input_choice", | |||||
"_value": { | |||||
"candidates": [ | |||||
"normal_n2_p0", | |||||
"normal_n2_p1" | |||||
], | |||||
"n_chosen": 2 | |||||
} | |||||
}, | |||||
... | |||||
} | |||||
``` | |||||
`best_selected_space.json` | |||||
``` | |||||
{
"op_list": { | |||||
"_type": "layer_choice", | |||||
"_value": [ | |||||
"maxpool", | |||||
"avgpool", | |||||
"skipconnect", | |||||
"sepconv3x3", | |||||
"sepconv5x5", | |||||
"dilconv3x3", | |||||
"dilconv5x5" | |||||
] | |||||
}, | |||||
"best_selected_space": { | |||||
"normal_n2_p0": [ | |||||
false, | |||||
false, | |||||
false, | |||||
false, | |||||
true, | |||||
false, | |||||
false | |||||
], | |||||
"normal_n2_p1": [ | |||||
true, | |||||
false, | |||||
false, | |||||
false, | |||||
false, | |||||
false, | |||||
false | |||||
], | |||||
... | |||||
} | |||||
``` |
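For reference, `pdarts_retrain.py` consumes this file by reading the `best_selected_space` entry and fixing the supernet to the selected architecture; a minimal sketch (assuming `model` is the `CNN` built as in `pdarts_retrain.py`):
```
import json
from pytorch.fixed import apply_fixed_architecture

# path is illustrative; see the retrain command above
with open('experiment_id/best_selected_space.json') as f:
    fixed_arc = json.load(f)
# replaces every LayerChoice / InputChoice in the model with the selected ops
apply_fixed_architecture(model, fixed_arc=fixed_arc["best_selected_space"])
```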
@@ -0,0 +1,46 @@ | |||||
from abc import ABC, abstractmethod | |||||
class Retrainer(ABC): | |||||
""" | |||||
    Train the best-performing model from scratch without structure optimization.
    To implement a new retrainer, users need to implement:
        method: "train"
        method: "__init__"
            super().__init__() must be called in the __init__ method
""" | |||||
@abstractmethod | |||||
def train(self): | |||||
""" | |||||
Override the method to train. | |||||
""" | |||||
raise NotImplementedError | |||||
def validate(self): | |||||
""" | |||||
Override the method to validate. | |||||
""" | |||||
raise NotImplementedError | |||||
def export(self, file): | |||||
""" | |||||
Override the method to export to file. | |||||
Parameters | |||||
---------- | |||||
file : str | |||||
File path to export to. | |||||
""" | |||||
raise NotImplementedError | |||||
def checkpoint(self): | |||||
""" | |||||
Override to dump a checkpoint. | |||||
""" | |||||
raise NotImplementedError |
@@ -0,0 +1,56 @@ | |||||
from abc import ABC, abstractmethod | |||||
class Selector(ABC): | |||||
""" | |||||
    Choose the best model from a group of candidates.
To implement a new selector, users need to implement: | |||||
method: "fit" | |||||
method: "__init__" | |||||
super().__init__() must be called in __init__ method | |||||
    Parameters
    ----------
    candidates
        candidates to be evaluated

    Examples
    --------
    ::

        class HPOSelector(Selector):
            def __init__(self, *args, single_candidate=True):
                super().__init__(single_candidate)
                self.args = args
            def fit(self):
                # only one candidate, so fit is a no-op
                pass
""" | |||||
@abstractmethod | |||||
def __init__(self, single_candidate=True): | |||||
self.single_candidate = single_candidate | |||||
self._valid() | |||||
@abstractmethod | |||||
def fit(self, candidates=None): | |||||
""" | |||||
        Evaluate the candidates to select the best one.
        Any optimization algorithm could be implemented here.
        If the input has only one candidate, just return it directly.
""" | |||||
raise NotImplementedError | |||||
    def _valid(self):
if self.single_candidate: | |||||
print("### single model, selecting finished ###") | |||||
exit(0) | |||||
@@ -0,0 +1 @@ | |||||
from .mutator import RandomMutator |
@@ -0,0 +1,36 @@ | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from nni.nas.pytorch.mutator import Mutator | |||||
from nni.nas.pytorch.mutables import LayerChoice, InputChoice | |||||
class RandomMutator(Mutator): | |||||
""" | |||||
    Random mutator that samples a random candidate from the search space each time ``reset()`` is called.
    It uses PyTorch's random functions, so users can set the PyTorch seed to ensure deterministic behavior.
""" | |||||
def sample_search(self): | |||||
""" | |||||
Sample a random candidate. | |||||
""" | |||||
result = dict() | |||||
for mutable in self.mutables: | |||||
if isinstance(mutable, LayerChoice): | |||||
gen_index = torch.randint(high=len(mutable), size=(1, )) | |||||
result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool() | |||||
elif isinstance(mutable, InputChoice): | |||||
if mutable.n_chosen is None: | |||||
result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool() | |||||
else: | |||||
perm = torch.randperm(mutable.n_candidates) | |||||
mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)] | |||||
result[mutable.key] = torch.tensor(mask, dtype=torch.bool) # pylint: disable=not-callable | |||||
return result | |||||
def sample_final(self): | |||||
""" | |||||
Same as :meth:`sample_search`. | |||||
""" | |||||
return self.sample_search() |
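# Minimal usage sketch (assuming ``model`` is an nn.Module built with
# LayerChoice / InputChoice mutables):
#
#     mutator = RandomMutator(model)
#     mutator.reset()                 # samples a new random candidate
#     masks = mutator.sample_final()  # dict of boolean masks keyed by mutable key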
@@ -0,0 +1,3 @@ | |||||
from .evolution import SPOSEvolution | |||||
from .mutator import SPOSSupernetTrainingMutator | |||||
from .trainer import SPOSSupernetTrainer |