You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

res_net.py 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. '''
  2. Properly implemented ResNet-s for CIFAR10 as described in paper [1].
  3. The implementation and structure of this file are heavily influenced by [2],
  4. which is implemented for ImageNet and lacks option A (the zero-padding identity shortcut).
  5. Moreover, most of the implementations on the web are copy-pasted from
  6. torchvision's resnet and have the wrong number of parameters.
  7. Proper ResNets for CIFAR10 (for fair comparison etc.) have the following
  8. number of layers and parameters:
  9. name | layers | params
  10. ResNet20 | 20 | 0.27M
  11. ResNet32 | 32 | 0.46M
  12. ResNet44 | 44 | 0.66M
  13. ResNet56 | 56 | 0.85M
  14. ResNet110 | 110 | 1.7M
  15. ResNet1202| 1202 | 19.4M
  16. which this implementation indeed has.
  17. Reference:
  18. [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
  19. Deep Residual Learning for Image Recognition. arXiv:1512.03385
  20. [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
  21. If you use this implementation in your work, please don't forget to mention the
  22. author, Yerlan Idelbayev.
  23. '''
  24. import math
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. import torch.nn.init as init
  29. from torch.autograd import Variable
  30. __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
  31. def _weights_init(m):
  32. if isinstance(m, nn.Linear):
  33. nn.init.kaiming_normal_(m.weight)
  34. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
  35. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  36. nn.init.uniform_(m.bias, -bound, bound)
  37. elif isinstance(m, nn.Conv2d):
  38. nn.init.kaiming_normal_(m.weight)
  39. class LambdaLayer(nn.Module):
  40. def __init__(self, lambd):
  41. super(LambdaLayer, self).__init__()
  42. self.lambd = lambd
  43. def forward(self, x):
  44. return self.lambd(x)
  45. class BasicBlock(nn.Module):
  46. expansion = 1
  47. def __init__(self, in_planes, planes, stride=1, option='A'):
  48. super(BasicBlock, self).__init__()
  49. self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
  50. self.bn1 = nn.BatchNorm2d(planes)
  51. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
  52. self.bn2 = nn.BatchNorm2d(planes)
  53. self.shortcut = nn.Sequential()
  54. if stride != 1 or in_planes != planes:
  55. if option == 'A':
  56. """
  57. For CIFAR10 ResNet paper uses option A.
  58. """
  59. self.shortcut = LambdaLayer(lambda x:
  60. F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant",
  61. 0))
  62. elif option == 'B':
  63. self.shortcut = nn.Sequential(
  64. nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
  65. nn.BatchNorm2d(self.expansion * planes)
  66. )
  67. def forward(self, x):
  68. out = F.relu(self.bn1(self.conv1(x)))
  69. out = self.bn2(self.conv2(out))
  70. out += self.shortcut(x)
  71. out = F.relu(out)
  72. return out
  73. class ResNet(nn.Module):
  74. def __init__(self, block, num_blocks, num_classes=10):
  75. super(ResNet, self).__init__()
  76. self.in_planes = 16
  77. self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
  78. self.bn1 = nn.BatchNorm2d(16)
  79. self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
  80. self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
  81. self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
  82. self.linear = nn.Linear(64, num_classes)
  83. self.apply(_weights_init)
  84. def _make_layer(self, block, planes, num_blocks, stride):
  85. strides = [stride] + [1] * (num_blocks - 1)
  86. layers = []
  87. for stride in strides:
  88. layers.append(block(self.in_planes, planes, stride))
  89. self.in_planes = planes * block.expansion
  90. return nn.Sequential(*layers)
  91. def forward(self, x):
  92. out = F.relu(self.bn1(self.conv1(x)))
  93. out = self.layer1(out)
  94. out = self.layer2(out)
  95. out = self.layer3(out)
  96. out = F.avg_pool2d(out, out.size()[3])
  97. out = out.view(out.size(0), -1)
  98. out = self.linear(out)
  99. return out
  100. def resnet20():
  101. return ResNet(BasicBlock, [3, 3, 3])
  102. def resnet32():
  103. return ResNet(BasicBlock, [5, 5, 5])
  104. def resnet44():
  105. return ResNet(BasicBlock, [7, 7, 7])
  106. def resnet56():
  107. return ResNet(BasicBlock, [9, 9, 9], num_classes=100)
  108. def resnet110():
  109. return ResNet(BasicBlock, [18, 18, 18])
  110. def resnet1202():
  111. return ResNet(BasicBlock, [200, 200, 200])
  112. def test(net):
  113. import numpy as np
  114. total_params = 0
  115. for x in filter(lambda p: p.requires_grad, net.parameters()):
  116. total_params += np.prod(x.data.numpy().shape)
  117. print("Total number of params", total_params)
  118. print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))
  119. def accuracy(output, target, topk=(1,)):
  120. """Computes the precision@k for the specified values of k"""
  121. maxk = max(topk)
  122. batch_size = target.size(0)
  123. _, pred = output.topk(maxk, 1, True, True)
  124. pred = pred.t()
  125. correct = pred.eq(target.view(1, -1).expand_as(pred))
  126. res = []
  127. for k in topk:
  128. correct_k = correct[:k].contiguous().view(-1).float().sum(0)
  129. res.append(correct_k.mul_(100.0 / batch_size))
  130. return res
  131. if __name__ == "__main__":
  132. for net_name in __all__:
  133. if net_name.startswith('resnet'):
  134. print(net_name)
  135. test(globals()[net_name]())
  136. print()

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)