import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample(x, size):
    """Bilinearly resize feature map `x` to spatial `size` (H, W)."""
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


# Batch-norm momentum intended for backbone layers configured by importers of this module.
batchnorm_momentum = 0.01 / 2


def get_n_params(parameters):
    """Count the total number of elements in an iterable of parameters."""
    total = 0
    for p in parameters:
        n = 1  # renamed from `nn` to avoid shadowing the torch.nn import
        for s in p.size():
            n *= s
        total += n
    return total


class _BNReluConv(nn.Sequential):
    """BN -> ReLU -> Conv block: the pre-activation ordering used throughout this module."""

    def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1,
                 bias=False, dilation=1):
        super(_BNReluConv, self).__init__()
        if batch_norm:
            self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum))
        # In-place ReLU is only used when BN precedes it, since BN's output is safe to overwrite.
        self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
        padding = k // 2  # 'same' padding for odd kernel sizes
        self.add_module('conv', nn.Conv2d(num_maps_in, num_maps_out, kernel_size=k,
                                          padding=padding, bias=bias, dilation=dilation))


class _Upsample(nn.Module):
    """Decoder step: project the skip connection, upsample `x` to its size, add, then blend."""

    def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3):
        super(_Upsample, self).__init__()
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn)

    def forward(self, x, skip):
        skip = self.bottleneck(skip)
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        return self.blend_conv(x)


class SpatialPyramidPooling(nn.Module):
    """Spatial pyramid pooling over progressively coarser average-pooling grids.

    Grids are rectangular by default, scaled to the input's aspect ratio;
    set `square_grid=True` for square pooling windows. Requires
    `num_levels <= len(grids)`.
    """

    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128,
                 grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True):
        super(SpatialPyramidPooling, self).__init__()
        self.grids = grids
        self.square_grid = square_grid
        self.spp = nn.Sequential()
        # Bottleneck the input to bt_size channels before pooling.
        self.spp.add_module('spp_bn',
                            _BNReluConv(num_maps_in, bt_size, k=1,
                                        bn_momentum=bn_momentum, batch_norm=use_bn))
        num_features = bt_size
        final_size = num_features
        for i in range(num_levels):
            final_size += level_size
            self.spp.add_module('spp' + str(i),
                                _BNReluConv(num_features, level_size, k=1,
                                            bn_momentum=bn_momentum, batch_norm=use_bn))
        # Fuse the concatenated pyramid levels down to out_size channels.
        self.spp.add_module('spp_fuse',
                            _BNReluConv(final_size, out_size, k=1,
                                        bn_momentum=bn_momentum, batch_norm=use_bn))

    def forward(self, x):
        levels = []
        target_size = x.size()[2:4]
        ar = target_size[1] / target_size[0]  # aspect ratio W / H
        x = self.spp[0](x)  # channel bottleneck
        levels.append(x)
        num = len(self.spp) - 1
        for i in range(1, num):
            if not self.square_grid:
                # Rectangular grid: keep the pooling grid's aspect ratio close to the input's.
                grid_size = (self.grids[i - 1], max(1, round(ar * self.grids[i - 1])))
                x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            else:
                x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1])
            level = self.spp[i](x_pooled)
            level = upsample(level, target_size)
            levels.append(level)
        x = torch.cat(levels, 1)
        return self.spp[-1](x)


class _UpsampleBlend(nn.Module):
    """Like _Upsample, but assumes the skip already has `num_features` channels (no bottleneck)."""

    def __init__(self, num_features, use_bn=True):
        super(_UpsampleBlend, self).__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=3, batch_norm=use_bn)

    def forward(self, x, skip):
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        return self.blend_conv(x)
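

# --- Minimal usage sketch (not part of the original module) ---
# The tensor shapes and constructor arguments below are illustrative assumptions,
# roughly matching features from a ResNet-18-style backbone; only the class
# signatures above are given by the source.
if __name__ == '__main__':
    torch.manual_seed(0)
    # SPP over a 512-channel feature map; default grids (6, 3, 2, 1), first 3 used.
    spp = SpatialPyramidPooling(num_maps_in=512, num_levels=3).eval()
    feats = torch.randn(1, 512, 16, 32)
    with torch.no_grad():
        context = spp(feats)  # -> (1, 128, 16, 32): out_size channels, input resolution
        # Decoder step: blend the SPP output with a higher-resolution skip connection.
        up = _Upsample(num_maps_in=128, skip_maps_in=256, num_maps_out=128).eval()
        skip = torch.randn(1, 256, 32, 64)
        out = up(context, skip)  # -> (1, 128, 32, 64): upsampled to the skip's resolution
    print(get_n_params(spp.parameters()), context.shape, out.shape)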