|
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT license.
-
- import torch.nn as nn
-
- import sys
- sys.path.append('..'+ '/' + '..')
- from pytorch import mutables # LayerChoice, InputChoice, MutableScope
- from ops import FactorizedReduce, ConvBranch, PoolBranch
-
-
- class ENASLayer(mutables.MutableScope):
-
- def __init__(self, key, prev_labels, in_filters, out_filters):
- super().__init__(key)
- self.in_filters = in_filters
- self.out_filters = out_filters
- self.mutable = mutables.LayerChoice([
- ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
- ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
- ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
- ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
- PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
- PoolBranch('max', in_filters, out_filters, 3, 1, 1)
- ])
- if len(prev_labels) > 0:
- self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
- else:
- self.skipconnect = None
- self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
-
- def forward(self, prev_layers):
- out = self.mutable(prev_layers[-1])
- if self.skipconnect is not None:
- connection = self.skipconnect(prev_layers[:-1])
- if connection is not None:
- out += connection
- return self.batch_norm(out)
-
-
- class GeneralNetwork(nn.Module):
- def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
- dropout_rate=0.0):
- super().__init__()
- self.num_layers = num_layers
- self.num_classes = num_classes
- self.out_filters = out_filters
-
- self.stem = nn.Sequential(
- nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
- nn.BatchNorm2d(out_filters)
- )
-
- pool_distance = self.num_layers // 3
- self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
- self.dropout_rate = dropout_rate
- self.dropout = nn.Dropout(self.dropout_rate)
- self.layers = nn.ModuleList()
- self.pool_layers = nn.ModuleList()
- labels = []
- for layer_id in range(self.num_layers):
- labels.append("layer_{}".format(layer_id))
- if layer_id in self.pool_layers_idx:
- self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
- self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))
-
- self.gap = nn.AdaptiveAvgPool2d(1)
- self.dense = nn.Linear(self.out_filters, self.num_classes)
-
- def forward(self, x):
- bs = x.size(0)
- cur = self.stem(x)
-
- layers = [cur]
-
- for layer_id in range(self.num_layers):
- cur = self.layers[layer_id](layers)
- layers.append(cur)
- if layer_id in self.pool_layers_idx:
- for i, layer in enumerate(layers):
- layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
- cur = layers[-1]
-
- cur = self.gap(cur).view(bs, -1)
- cur = self.dropout(cur)
- logits = self.dense(cur)
- return logits
|