
util.py 4.7 kB

import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample(x, size):
    """Bilinearly resize x to the given spatial size."""
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


batchnorm_momentum = 0.01 / 2


def get_n_params(parameters):
    """Count the total number of elements in an iterable of parameters."""
    pp = 0
    for p in parameters:
        n = 1  # use n, not nn, to avoid shadowing the torch.nn import
        for s in p.size():
            n *= s
        pp += n
    return pp
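
# Usage sketch for get_n_params (illustrative module, not from the original
# file): a Conv2d(3, 16, kernel_size=3) has 3*16*3*3 weights plus 16 biases.
#   conv = nn.Conv2d(3, 16, kernel_size=3)
#   get_n_params(conv.parameters())  # -> 448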


class _BNReluConv(nn.Sequential):
    """Pre-activation block: BatchNorm -> ReLU -> Conv."""

    def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True,
                 bn_momentum=0.1, bias=False, dilation=1):
        super(_BNReluConv, self).__init__()
        if batch_norm:
            self.add_module('norm', nn.BatchNorm2d(num_maps_in,
                                                   momentum=bn_momentum))
        # In-place ReLU is only enabled together with batch norm, whose
        # output is not reused anywhere else.
        self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
        padding = k // 2  # 'same' padding for odd k (when dilation == 1)
        self.add_module('conv', nn.Conv2d(num_maps_in, num_maps_out,
                                          kernel_size=k, padding=padding,
                                          bias=bias, dilation=dilation))
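
# Shape sketch for _BNReluConv (assumed sizes): spatial resolution is
# preserved, only the channel count changes.
#   block = _BNReluConv(64, 128, k=3)
#   block(torch.randn(1, 64, 32, 32)).shape  # -> torch.Size([1, 128, 32, 32])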


class _Upsample(nn.Module):
    """Upsample x to the skip's resolution, sum with the bottlenecked skip,
    then blend with a k x k convolution."""

    def __init__(self, num_maps_in, skip_maps_in, num_maps_out,
                 use_bn=True, k=3):
        super(_Upsample, self).__init__()
        # 1x1 bottleneck projects the skip to num_maps_in channels so it can
        # be summed with the upsampled features.
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1,
                                      batch_norm=use_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k,
                                      batch_norm=use_bn)

    def forward(self, x, skip):
        skip = self.bottleneck(skip)
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        x = self.blend_conv(x)
        return x
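
# Shape sketch for _Upsample (assumed sizes): the output inherits the skip's
# spatial resolution and has num_maps_out channels.
#   up = _Upsample(num_maps_in=128, skip_maps_in=256, num_maps_out=128)
#   x, skip = torch.randn(1, 128, 16, 16), torch.randn(1, 256, 32, 32)
#   up(x, skip).shape  # -> torch.Size([1, 128, 32, 32])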


class SpatialPyramidPooling(nn.Module):
    """Pyramid pooling: a 1x1 bottleneck, num_levels pooled 1x1 branches that
    are upsampled back and concatenated, and a final 1x1 fusion conv."""

    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128,
                 out_size=128, grids=(6, 3, 2, 1), square_grid=False,
                 bn_momentum=0.1, use_bn=True):
        super(SpatialPyramidPooling, self).__init__()
        self.grids = grids
        self.square_grid = square_grid
        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn',
                            _BNReluConv(num_maps_in, bt_size, k=1,
                                        bn_momentum=bn_momentum,
                                        batch_norm=use_bn))
        num_features = bt_size
        final_size = num_features
        for i in range(num_levels):
            final_size += level_size
            self.spp.add_module('spp' + str(i),
                                _BNReluConv(num_features, level_size, k=1,
                                            bn_momentum=bn_momentum,
                                            batch_norm=use_bn))
        self.spp.add_module('spp_fuse',
                            _BNReluConv(final_size, out_size, k=1,
                                        bn_momentum=bn_momentum,
                                        batch_norm=use_bn))

    def forward(self, x):
        levels = []
        target_size = x.size()[2:4]
        # Aspect ratio (width / height) keeps the pooling cells roughly
        # square on non-square inputs.
        ar = target_size[1] / target_size[0]
        x = self.spp[0](x)
        levels.append(x)
        num = len(self.spp) - 1
        for i in range(1, num):
            if not self.square_grid:
                # Grid is (rows, cols); columns scale with the aspect ratio.
                grid_size = (self.grids[i - 1],
                             max(1, round(ar * self.grids[i - 1])))
                x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            else:
                x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1])
            level = self.spp[i](x_pooled)
            level = upsample(level, target_size)
            levels.append(level)
        x = torch.cat(levels, 1)
        x = self.spp[-1](x)
        return x
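
# Shape sketch for SpatialPyramidPooling (assumed sizes): with num_levels=3
# and the defaults, spp_fuse sees 512 + 3 * 128 = 896 channels and reduces
# them to out_size.
#   spp = SpatialPyramidPooling(256, num_levels=3)
#   spp(torch.randn(1, 256, 32, 32)).shape  # -> torch.Size([1, 128, 32, 32])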


class _UpsampleBlend(nn.Module):
    """Like _Upsample, but the skip already has num_features channels, so no
    bottleneck is needed before the residual sum."""

    def __init__(self, num_features, use_bn=True):
        super(_UpsampleBlend, self).__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=3,
                                      batch_norm=use_bn)

    def forward(self, x, skip):
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        x = self.blend_conv(x)
        return x
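

# Minimal smoke test, not part of the original module; the tensor shapes
# below are assumptions chosen only to exercise each block once.
if __name__ == '__main__':
    spp = SpatialPyramidPooling(256, num_levels=3)
    blend = _UpsampleBlend(128)
    feats = torch.randn(2, 256, 16, 16)  # coarse backbone features (assumed)
    skip = torch.randn(2, 128, 32, 32)   # lateral skip features (assumed)
    out = blend(spp(feats), skip)
    print(out.shape)  # torch.Size([2, 128, 32, 32])
    print(get_n_params(spp.parameters()), 'parameters in the SPP module')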