
vit_seg_modeling_resnet_skip.py

import math
from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)

class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization: each filter is normalized to zero
    mean and unit variance before the convolution is applied."""

    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride,
                     padding=0, bias=bias)

class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block."""

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has the stride on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if stride != 1 or cin != cout:
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        # Residual branch: project the input when the shape changes.
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand.
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        """Copy pretrained NumPy weights (TF-style keys) into this unit.
        The in-place copy_ calls are wrapped in torch.no_grad() so the
        method is safe to call outside an external no-grad context."""
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        with torch.no_grad():
            self.conv1.weight.copy_(conv1_weight)
            self.conv2.weight.copy_(conv2_weight)
            self.conv3.weight.copy_(conv3_weight)

            self.gn1.weight.copy_(gn1_weight.view(-1))
            self.gn1.bias.copy_(gn1_bias.view(-1))
            self.gn2.weight.copy_(gn2_weight.view(-1))
            self.gn2.bias.copy_(gn2_bias.view(-1))
            self.gn3.weight.copy_(gn3_weight.view(-1))
            self.gn3.bias.copy_(gn3_bias.view(-1))

            if hasattr(self, 'downsample'):
                proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
                proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
                proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

                self.downsample.weight.copy_(proj_conv_weight)
                self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
                self.gn_proj.bias.copy_(proj_gn_bias.view(-1))

class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet model."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        features = []
        b, c, in_size, _ = x.size()  # assumes a square input
        x = self.root(x)
        features.append(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0)
        for i in range(len(self.body) - 1):
            x = self.body[i](x)
            # The unpadded max-pool can shrink odd-sized maps by one pixel;
            # zero-pad the skip feature back to the size the decoder expects.
            right_size = int(in_size / 4 / (i + 1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert 0 < pad < 3, "x {} should have spatial size {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size),
                                   device=x.device, dtype=x.dtype)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]
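
As a quick sanity check, here is a minimal usage sketch (not part of the file above). The block configuration (3, 4, 9) with width_factor=1 is the R50-style setup this backbone is typically paired with in TransUNet; the 224x224 input size and batch size of 2 are assumptions for this demo.

model = ResNetV2(block_units=(3, 4, 9), width_factor=1)
x = torch.randn(2, 3, 224, 224)  # hypothetical input batch
out, skips = model(x)
print(out.shape)                 # torch.Size([2, 1024, 14, 14])
for f in skips:                  # skip features, deepest first
    print(f.shape)               # [2, 512, 28, 28], [2, 256, 56, 56], [2, 64, 112, 112]

The returned features list is reversed so the decoder receives the deepest skip connection first.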

Network code reproduction
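
The key layout expected by load_from follows directly from the pjoin calls above: a dict-like collection of NumPy arrays with keys such as "block1/unit1/conv1/kernel". Below is a hedged sketch of driving the loader, assuming a pretrained checkpoint stored as an .npz file; the filename is hypothetical, and the root conv/gn weights would still need to be copied separately.

import numpy as np

weights = np.load("r50_pretrained.npz")  # hypothetical file; keys must match the pjoin layout
model = ResNetV2(block_units=(3, 4, 9), width_factor=1)
for bname, block in model.body.named_children():
    for uname, unit in block.named_children():
        unit.load_from(weights, n_block=bname, n_unit=uname)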