
util.py 3.6 kB

import torch
import torch.nn as nn
import torch.nn.functional as F

# Bilinear resize to an arbitrary (H, W); align_corners=False is the
# default PyTorch convention.
upsample = lambda x, size: F.interpolate(x, size, mode='bilinear', align_corners=False)
batchnorm_momentum = 0.01 / 2
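
# A minimal sketch (not part of the original file) of the `upsample` helper:
# it resizes a feature map to an arbitrary (H, W) with bilinear
# interpolation. The shapes below are illustrative assumptions.
def _demo_upsample():
    x = torch.randn(1, 8, 16, 16)   # (N, C, H, W)
    y = upsample(x, (32, 32))
    assert y.shape == (1, 8, 32, 32)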

def get_n_params(parameters):
    # Count the total number of scalar values across all parameter tensors.
    pp = 0
    for p in parameters:
        n = 1
        for s in list(p.size()):
            n *= s
        pp += n
    return pp
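
# Hypothetical usage sketch (not in the original file): counting the
# parameters of a single convolution with get_n_params.
def _demo_get_n_params():
    conv = nn.Conv2d(3, 16, kernel_size=3)
    # weight 16*3*3*3 = 432 plus bias 16 -> 448 parameters in total
    assert get_n_params(conv.parameters()) == 448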

class _BNReluConv(nn.Sequential):
    def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1):
        super(_BNReluConv, self).__init__()
        if batch_norm:
            self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum))
        self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
        padding = k // 2  # 'same' convolution for odd k and dilation=1
        self.add_module('conv', nn.Conv2d(num_maps_in, num_maps_out,
                                          kernel_size=k, padding=padding, bias=bias, dilation=dilation))
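
# Illustrative sketch (shapes are assumptions, not from the original file):
# _BNReluConv changes the channel count but keeps the spatial size for odd k.
def _demo_bn_relu_conv():
    block = _BNReluConv(8, 16, k=3)
    block.eval()  # use running BN statistics for a deterministic check
    y = block(torch.randn(2, 8, 24, 24))
    assert y.shape == (2, 16, 24, 24)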

class _Upsample(nn.Module):
    def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3):
        super(_Upsample, self).__init__()
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn)

    def forward(self, x, skip):
        skip = self.bottleneck.forward(skip)  # project skip to num_maps_in channels
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)            # match the skip's spatial resolution
        x = x + skip                          # fuse by element-wise addition
        x = self.blend_conv.forward(x)
        return x
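
# Usage sketch (all sizes are assumptions): _Upsample fuses coarse decoder
# features with a higher-resolution encoder skip; the output follows the
# skip's spatial resolution.
def _demo_upsample_module():
    up = _Upsample(num_maps_in=32, skip_maps_in=64, num_maps_out=32)
    up.eval()
    x = torch.randn(1, 32, 8, 8)        # coarse features
    skip = torch.randn(1, 64, 16, 16)   # finer skip features
    assert up(x, skip).shape == (1, 32, 16, 16)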

class SpatialPyramidPooling(nn.Module):
    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128,
                 grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True):
        super(SpatialPyramidPooling, self).__init__()
        self.grids = grids
        self.square_grid = square_grid
        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn',
                            _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))
        num_features = bt_size
        final_size = num_features
        for i in range(num_levels):
            final_size += level_size
            self.spp.add_module('spp' + str(i),
                                _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))
        self.spp.add_module('spp_fuse',
                            _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))

    def forward(self, x):
        levels = []
        target_size = x.size()[2:4]
        ar = target_size[1] / target_size[0]  # aspect ratio (width / height)
        x = self.spp[0].forward(x)            # bottleneck to bt_size channels
        levels.append(x)
        num = len(self.spp) - 1
        for i in range(1, num):
            if not self.square_grid:
                # rectangular pooling grid that preserves the input aspect ratio
                grid_size = (self.grids[i - 1], max(1, round(ar * self.grids[i - 1])))
                x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            else:
                x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1])
            level = self.spp[i].forward(x_pooled)
            level = upsample(level, target_size)
            levels.append(level)
        x = torch.cat(levels, 1)              # concatenate bottleneck + all pyramid levels
        x = self.spp[-1].forward(x)           # fuse to out_size channels
        return x
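
# Usage sketch (all sizes are assumptions): with num_levels=3 the module
# pools over the first three entries of `grids` and fuses the concatenated
# levels back to out_size channels at the input resolution.
def _demo_spp():
    spp = SpatialPyramidPooling(num_maps_in=128, num_levels=3, bt_size=128,
                                level_size=32, out_size=128)
    spp.eval()
    x = torch.randn(1, 128, 16, 32)   # H=16, W=32 -> aspect ratio 2.0
    assert spp(x).shape == (1, 128, 16, 32)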

class _UpsampleBlend(nn.Module):
    def __init__(self, num_features, use_bn=True):
        super(_UpsampleBlend, self).__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=3, batch_norm=use_bn)

    def forward(self, x, skip):
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)   # match the skip's spatial resolution
        x = x + skip
        x = self.blend_conv.forward(x)
        return x
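
# Sketch (assumed shapes): _UpsampleBlend is the channel-preserving variant
# of _Upsample, for skips that already have num_features channels.
def _demo_upsample_blend():
    blend = _UpsampleBlend(num_features=32)
    blend.eval()
    x = torch.randn(1, 32, 8, 8)
    skip = torch.randn(1, 32, 16, 16)
    assert blend(x, skip).shape == (1, 32, 16, 16)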