|
- import os
- import argparse
- import logging
- import sys
-
- from collections import OrderedDict
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- from torchvision import datasets, transforms
-
- from pytorch.mutables import LayerChoice, InputChoice
- from mutator import ClassicMutator
- import numpy as np
-
class Net(nn.Module):
    """Small CNN whose architecture is partly selected by a NAS mutator.

    Searchable parts (each resolved by the mutator via its ``key``):

    * ``conv1`` (key ``'conv1'``): 5x5 or 3x3 convolution, 1 -> 20 channels,
      stride 1, no padding.
    * ``mid_conv`` (key ``'mid_conv'``): 3x3 or 5x5 convolution,
      20 -> 20 channels, padded so height and width are unchanged.
    * ``input_switch`` (key ``'skip'``): one of two candidates
      ``[zeros, activation_before_mid_conv]`` is added after ``mid_conv``.
      This controls whether a residual connection around ``mid_conv`` is
      active.

    Args:
        hidden_size: width of the hidden fully-connected layer ``fc1``.
    """

    # Number of features fed to fc1: the final feature map must be
    # 50 channels x 4 x 4.
    _FLAT_FEATURES = 4 * 4 * 50

    def __init__(self, hidden_size):
        super().__init__()
        # First-layer candidates; the OrderedDict fixes the candidate order.
        first_conv_options = OrderedDict([
            ("conv5x5", nn.Conv2d(1, 20, 5, 1)),
            ("conv3x3", nn.Conv2d(1, 20, 3, 1)),
        ])
        self.conv1 = LayerChoice(first_conv_options, key='conv1')

        # Middle-layer candidates; padding = kernel_size // 2 with stride 1
        # keeps the spatial size, so the output can be summed with its input.
        middle_conv_options = OrderedDict([
            ("conv3x3", nn.Conv2d(20, 20, 3, 1, padding=1)),
            ("conv5x5", nn.Conv2d(20, 20, 5, 1, padding=2)),
        ])
        self.mid_conv = LayerChoice(middle_conv_options, key='mid_conv')

        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)

        # Picks one of two candidate tensors to add around mid_conv (skip or no skip).
        self.input_switch = InputChoice(n_candidates=2, n_chosen=1, key='skip')

    def forward(self, x):
        """Return ``log_softmax`` scores over the 10 outputs of ``fc2``.

        ``x`` must have a single input channel (``conv1`` candidates take 1
        input channel). The spatial size must reduce to 4x4 before ``fc1``.
        For example, a 28x28 input gives 4x4 for either ``conv1`` choice.
        This is an assumption about the caller; confirm against the data
        pipeline.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        shortcut = x
        out = F.relu(self.mid_conv(shortcut))
        # The all-zero candidate makes "no skip" an additive no-op of the
        # same shape as the real shortcut.
        skip = self.input_switch([torch.zeros_like(shortcut), shortcut])
        out = F.max_pool2d(F.relu(self.conv2(out + skip)), 2, 2)
        # NOTE(review): view(-1, N) will regroup features across samples
        # (rather than fail) if the input's spatial size does not yield 4x4
        # here and the total element count still divides evenly.
        out = out.view(-1, self._FLAT_FEATURES)
        out = F.relu(self.fc1(out))
        return F.log_softmax(self.fc2(out), dim=1)
|