"""VGG-style convolutional classifiers (VGG11/13/16/19 layer layouts)."""

import torch
import torch.nn as nn


# Layer layout per variant: an int is the output channel count of a 3x3
# convolution (followed by batch norm and ReLU); 'M' is a 2x2 max-pool.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M',
              512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
              512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
              512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
              512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG feature extractor followed by three fully connected layers.

    ``fc1`` takes exactly 512 features per sample. Every layout in ``cfg``
    ends with 512 channels and contains five stride-2 pools, so the
    flattened feature map only has 512 elements when the pooled spatial
    size is 1x1 (for example a 3x32x32 input).

    Args:
        vgg_name: Key into ``cfg`` selecting the layer layout.
        num_class: Number of outputs produced by the final linear layer.

    Raises:
        KeyError: If ``vgg_name`` is not a key of ``cfg``.
    """

    def __init__(self, vgg_name, num_class=10):
        super(VGG, self).__init__()
        try:
            layer_spec = cfg[vgg_name]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a
            # message that tells the caller what would have been accepted.
            raise KeyError(
                'unknown VGG variant {!r}; expected one of {}'.format(
                    vgg_name, sorted(cfg)))
        self.features = self._make_layers(layer_spec)
        self.fc1 = nn.Linear(512, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.classifier = nn.Linear(4096, num_class)

    def forward(self, x):
        """Return the ``(batch, num_class)`` output for an image batch ``x``."""
        out = self.features(x)
        # Collapse everything but the batch dimension. flatten() matches
        # view(N, -1) wherever view() worked, and additionally copes with
        # layouts view() rejects (non-contiguous memory, empty batch).
        out = torch.flatten(out, 1)
        # NOTE(review): fc1 and fc2 are applied back to back with no
        # activation in between, so together they form one linear map.
        # Left unchanged so existing checkpoints keep producing the same
        # outputs -- confirm whether a ReLU/Dropout was intended here.
        out = self.fc2(self.fc1(out))
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        """Build the convolutional stack described by ``cfg``.

        Args:
            cfg: Sequence of layer specs; ``'M'`` adds a 2x2 max-pool, any
                other entry is used as the output channel count of a
                conv + batch-norm + ReLU group. (The name shadows the
                module-level table; it is kept for keyword callers.)

        Returns:
            An ``nn.Sequential`` that expects a 3-channel input.
        """
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.extend([
                    nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True),
                ])
                in_channels = x
        # A 1x1, stride-1 average pool leaves its input unchanged; it is
        # kept so the module list matches the original architecture.
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)


def vgg16(num_class=10):
    """Construct a ``VGG`` using the 'VGG16' layout."""
    return VGG('VGG16', num_class)


def vgg19(num_class=10):
    """Construct a ``VGG`` using the 'VGG19' layout."""
    return VGG('VGG19', num_class)