import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample(x, size):
    """Bilinearly resize x to an explicit (H, W) size."""
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


# Default BatchNorm momentum used by the model (0.01 / 2 = 0.005).
batchnorm_momentum = 0.01 / 2
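
# Quick shape sketch for the helper above (hypothetical sizes, not part of the
# original module):
#   upsample(torch.randn(1, 8, 16, 16), (32, 32)).shape  # -> (1, 8, 32, 32)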


def get_n_params(parameters):
    """Count the total number of elements across an iterable of tensors."""
    return sum(p.numel() for p in parameters)
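
# Minimal usage sketch (the toy layer is an illustration only):
#   conv = nn.Conv2d(3, 16, kernel_size=3)
#   get_n_params(conv.parameters())  # 3*16*3*3 weights + 16 biases = 448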


class _BNReluConv(nn.Sequential):
    """Pre-activation block: BatchNorm -> ReLU -> Conv."""

    def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1):
        super(_BNReluConv, self).__init__()
        if batch_norm:
            self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum))
        # In-place ReLU is only enabled when BN precedes it; without BN it would
        # mutate the block's input tensor in place.
        self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
        # 'same' padding for odd k, scaled by dilation so dilated convs also
        # preserve the spatial size (plain k // 2 would shrink them).
        padding = (k // 2) * dilation
        self.add_module('conv', nn.Conv2d(num_maps_in, num_maps_out,
                                          kernel_size=k, padding=padding, bias=bias, dilation=dilation))
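
# Shape sketch for _BNReluConv (hypothetical tensor sizes): 'same' padding keeps
# the spatial dimensions, only the channel count changes.
#   block = _BNReluConv(64, 128, k=3)
#   block(torch.randn(2, 64, 32, 32)).shape  # -> (2, 128, 32, 32)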


class _Upsample(nn.Module):
    """Decoder step: project the skip, upsample x to its size, sum, and blend."""

    def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3):
        super(_Upsample, self).__init__()
        # 1x1 bottleneck aligns the skip connection's channels with the decoder features.
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn)

    def forward(self, x, skip):
        skip = self.bottleneck(skip)
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        x = self.blend_conv(x)
        return x
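
# Decoder-fusion sketch for _Upsample (hypothetical feature shapes):
#   up = _Upsample(num_maps_in=128, skip_maps_in=256, num_maps_out=128)
#   x = torch.randn(1, 128, 16, 16)     # coarse decoder features
#   skip = torch.randn(1, 256, 32, 32)  # higher-resolution encoder features
#   up(x, skip).shape                   # -> (1, 128, 32, 32)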


class SpatialPyramidPooling(nn.Module):
    """SPP block: a 1x1 bottleneck, average pooling at several grid resolutions,
    then upsampling of all levels back to the input size and a 1x1 fuse conv."""

    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128,
                 grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True):
        super(SpatialPyramidPooling, self).__init__()
        self.grids = grids
        self.square_grid = square_grid
        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn',
                            _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))
        num_features = bt_size
        # The fuse conv sees the bottleneck output concatenated with every pooled level.
        final_size = num_features
        for i in range(num_levels):
            final_size += level_size
            self.spp.add_module('spp' + str(i),
                                _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))
        self.spp.add_module('spp_fuse',
                            _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))

    def forward(self, x):
        levels = []
        target_size = x.size()[2:4]

        # Aspect ratio (W / H), used to stretch pooling grids for non-square inputs.
        ar = target_size[1] / target_size[0]

        x = self.spp[0](x)
        levels.append(x)
        num = len(self.spp) - 1

        for i in range(1, num):
            if not self.square_grid:
                # Rectangular grid: keep the configured height, scale the width by ar.
                grid_size = (self.grids[i - 1], max(1, round(ar * self.grids[i - 1])))
                x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            else:
                x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1])
            level = self.spp[i](x_pooled)
            # Resize each pooled level back to the input resolution before concatenation.
            level = upsample(level, target_size)
            levels.append(level)
        x = torch.cat(levels, 1)
        x = self.spp[-1](x)
        return x
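
# SPP sketch (hypothetical input): three pooling levels over a 2:1 feature map,
# fused back to out_size channels at the input resolution.
#   spp = SpatialPyramidPooling(512, num_levels=3, grids=(8, 4, 2))
#   spp(torch.randn(1, 512, 32, 64)).shape  # -> (1, 128, 32, 64)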


class _UpsampleBlend(nn.Module):
    """Channel-preserving variant of _Upsample: no skip bottleneck, just
    upsample, sum, and blend."""

    def __init__(self, num_features, use_bn=True):
        super(_UpsampleBlend, self).__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=3, batch_norm=use_bn)

    def forward(self, x, skip):
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        x = self.blend_conv(x)
        return x
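
# _UpsampleBlend sketch (hypothetical shapes): both inputs must share the channel count.
#   blend = _UpsampleBlend(128)
#   blend(torch.randn(1, 128, 16, 16), torch.randn(1, 128, 32, 32)).shape  # -> (1, 128, 32, 32)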