You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

vgg.py 1.6 kB

4 years ago
Lines 1–48
  1. import torch
  2. import torch.nn as nn
  3. cfg = {
  4. 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  5. 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  6. 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  7. 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  8. }
  9. class VGG(nn.Module):
  10. def __init__(self, vgg_name, num_class=10):
  11. super(VGG, self).__init__()
  12. self.features = self._make_layers(cfg[vgg_name])
  13. self.fc1 = nn.Linear(512, 4096)
  14. self.fc2 = nn.Linear(4096, 4096)
  15. self.classifier = nn.Linear(4096, num_class)
  16. def forward(self, x):
  17. out = self.features(x)
  18. out = out.view(out.size(0), -1)
  19. out = self.fc2(self.fc1(out))
  20. out = self.classifier(out)
  21. return out
  22. def _make_layers(self, cfg):
  23. layers = []
  24. in_channels = 3
  25. for x in cfg:
  26. if x == 'M':
  27. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  28. else:
  29. layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
  30. nn.BatchNorm2d(x),
  31. nn.ReLU(inplace=True)]
  32. in_channels = x
  33. layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
  34. return nn.Sequential(*layers)
  35. def vgg16(num_class=10):
  36. return VGG('VGG16', num_class)
  37. def vgg19(num_class=10):
  38. return VGG('VGG19', num_class)