import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample(x, size):
    """Bilinearly upsample x to the target spatial size."""
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


# Module-level default for the BatchNorm momentum; not referenced below,
# presumably imported by other modules.
batchnorm_momentum = 0.01 / 2


def get_n_params(parameters):
    """Count the total number of scalar elements in an iterable of tensors."""
    total = 0
    for p in parameters:
        count = 1
        for s in p.size():
            count *= s
        total += count
    return total
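

# A minimal sanity check for get_n_params (illustrative, not part of the
# original utilities): the count should match summing Tensor.numel().
def _demo_get_n_params():
    conv = nn.Conv2d(3, 16, kernel_size=3, bias=True)
    # 3 * 16 * 3 * 3 weights + 16 biases = 448 parameters.
    assert get_n_params(conv.parameters()) == 448
    assert get_n_params(conv.parameters()) == sum(p.numel() for p in conv.parameters())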


class _BNReluConv(nn.Sequential):
    """Pre-activation block: optional BatchNorm, then ReLU, then Conv2d."""

    def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True,
                 bn_momentum=0.1, bias=False, dilation=1):
        super(_BNReluConv, self).__init__()
        if batch_norm:
            self.add_module('norm', nn.BatchNorm2d(num_maps_in,
                                                   momentum=bn_momentum))
        # In-place ReLU is only safe when BN has produced a fresh tensor.
        self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
        padding = k // 2  # 'same' padding for odd k (when dilation == 1)
        self.add_module('conv', nn.Conv2d(num_maps_in, num_maps_out,
                                          kernel_size=k, padding=padding,
                                          bias=bias, dilation=dilation))
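

# Hedged usage sketch (sizes are illustrative assumptions): a 3x3
# _BNReluConv preserves spatial size and maps 64 -> 128 channels.
def _demo_bn_relu_conv():
    block = _BNReluConv(64, 128, k=3)
    x = torch.randn(2, 64, 32, 32)
    assert block(x).shape == (2, 128, 32, 32)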


class _Upsample(nn.Module):
    """Decoder step: project the skip features to the working width with a
    1x1 _BNReluConv, upsample x to the skip resolution, add, then blend."""

    def __init__(self, num_maps_in, skip_maps_in, num_maps_out,
                 use_bn=True, k=3):
        super(_Upsample, self).__init__()
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1,
                                      batch_norm=use_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k,
                                      batch_norm=use_bn)

    def forward(self, x, skip):
        skip = self.bottleneck(skip)
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        x = self.blend_conv(x)
        return x
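

# Hedged usage sketch (shapes are illustrative assumptions): upsampling
# 8x8 decoder features to a 16x16 skip connection.
def _demo_upsample_module():
    up = _Upsample(num_maps_in=128, skip_maps_in=256, num_maps_out=128)
    x = torch.randn(2, 128, 8, 8)       # low-resolution decoder features
    skip = torch.randn(2, 256, 16, 16)  # higher-resolution encoder features
    assert up(x, skip).shape == (2, 128, 16, 16)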


class SpatialPyramidPooling(nn.Module):
    """Context module: a 1x1 bottleneck followed by average-pooled pyramid
    levels that are upsampled back, concatenated, and fused with a 1x1 conv.

    ``grids`` must provide at least ``num_levels`` entries.
    """

    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128,
                 out_size=128, grids=(6, 3, 2, 1), square_grid=False,
                 bn_momentum=0.1, use_bn=True):
        super(SpatialPyramidPooling, self).__init__()
        self.grids = grids
        self.square_grid = square_grid
        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn',
                            _BNReluConv(num_maps_in, bt_size, k=1,
                                        bn_momentum=bn_momentum,
                                        batch_norm=use_bn))
        num_features = bt_size
        final_size = num_features
        for i in range(num_levels):
            final_size += level_size
            self.spp.add_module('spp' + str(i),
                                _BNReluConv(num_features, level_size, k=1,
                                            bn_momentum=bn_momentum,
                                            batch_norm=use_bn))
        self.spp.add_module('spp_fuse',
                            _BNReluConv(final_size, out_size, k=1,
                                        bn_momentum=bn_momentum,
                                        batch_norm=use_bn))

    def forward(self, x):
        levels = []
        target_size = x.size()[2:4]

        # Width-to-height ratio, used to keep non-square pooling grids
        # roughly proportional to the feature map.
        ar = target_size[1] / target_size[0]

        x = self.spp[0](x)
        levels.append(x)
        num = len(self.spp) - 1

        for i in range(1, num):
            if not self.square_grid:
                grid_size = (self.grids[i - 1],
                             max(1, round(ar * self.grids[i - 1])))
                x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            else:
                x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1])
            level = self.spp[i](x_pooled)
            level = upsample(level, target_size)
            levels.append(level)
        x = torch.cat(levels, 1)
        x = self.spp[-1](x)
        return x
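

# Hedged shape check (all sizes are illustrative assumptions): output
# channels equal out_size and the input spatial size is preserved.
def _demo_spatial_pyramid_pooling():
    spp = SpatialPyramidPooling(num_maps_in=256, num_levels=3, grids=(8, 4, 2))
    x = torch.randn(2, 256, 16, 32)
    assert spp(x).shape == (2, 128, 16, 32)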


class _UpsampleBlend(nn.Module):
    """Like _Upsample, but assumes x and skip already share the same channel
    width, so the skip connection needs no 1x1 bottleneck."""

    def __init__(self, num_features, use_bn=True):
        super(_UpsampleBlend, self).__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=3,
                                      batch_norm=use_bn)

    def forward(self, x, skip):
        skip_size = skip.size()[2:4]
        x = upsample(x, skip_size)
        x = x + skip
        x = self.blend_conv(x)
        return x
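

# Hedged end-to-end sketch (all names and sizes here are assumptions, not
# part of the original module): how these utilities typically compose in a
# segmentation decoder, with SPP on the deepest features followed by an
# _Upsample step toward an encoder skip connection.
def _demo_decoder_composition():
    feats = torch.randn(1, 512, 8, 8)    # deepest encoder features
    skip = torch.randn(1, 256, 16, 16)   # one encoder skip connection
    spp = SpatialPyramidPooling(512, num_levels=3, bt_size=128,
                                level_size=32, out_size=128, grids=(8, 4, 2))
    up = _Upsample(num_maps_in=128, skip_maps_in=256, num_maps_out=128)
    x = spp(feats)       # (1, 128, 8, 8)
    x = up(x, skip)      # (1, 128, 16, 16)
    assert x.shape == (1, 128, 16, 16)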