
add DynamicConv

Huyf9 · 2 years ago · parent commit 6c75fbf6d8
1 changed file with 106 additions and 0 deletions

model/conv/DynamicConv.py  +106  -0

@@ -0,0 +1,106 @@
""" DynamicConv """
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, HeNormal, HeUniform


class Attention(nn.Cell):
"""Attention with temperature
"""
def __init__(self, in_channels, ratio, K, temperature=30, init_weight=True):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.temperature = temperature
assert in_channels > ratio
hidden_channels = in_channels // ratio
self.net = nn.SequentialCell(
nn.Conv2d(in_channels, hidden_channels, kernel_size=1, has_bias=False),
nn.ReLU(),
nn.Conv2d(hidden_channels, K, kernel_size=1, has_bias=False)
)
        self.softmax = nn.Softmax(axis=1)
if init_weight:
self.apply(self.init_weight)

def update_temperature(self):
"""update temperature
"""
if self.temperature > 1:
self.temperature -= 1

def init_weight(self, cell):
"""init Conv2d and BatchNorm2d
"""
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

def construct(self, x):
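        # global average pooling to 1x1, then two 1x1 convs produce one score per candidate kernel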
att = self.avgpool(x)
att = self.net(att).view(x.shape[0], -1)
        # temperature-scaled softmax over the K kernels; the temperature is annealed by update_temperature
        return self.softmax(att / self.temperature)


class DynamicConv(nn.Cell):
"""DynamicConv
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
dilation=1, groups=1, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.bias = bias
self.K = K
self.attention = Attention(in_channels, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight)

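        # K candidate kernels (and optional biases); each input mixes them with its attention scores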
self.weight = ms.Parameter(
ms.ops.randn((K, out_channels, in_channels // groups, kernel_size, kernel_size)), requires_grad=True)
if bias:
self.bias = ms.Parameter(ms.ops.randn((K, out_channels)), requires_grad=True)
else:
self.bias = None

if init_weight:
self.init_weight()

def init_weight(self):
"""init weight with HeUniform
"""
for i in range(self.K):
self.weight[i] = initializer(HeUniform(), self.weight[i].shape, self.weight[i].dtype)

def construct(self, x):
B, _, H, W = x.shape
softmax_att = self.attention(x)
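        # mix the K kernels per sample using the attention scores, then fold the batch into the
        # channel axis so one grouped convolution (groups = B * self.groups) applies each sample's kernel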
x = x.view(1, -1, H, W)
weight = self.weight.view(self.K, -1)
aggregate_weight = ms.ops.mm(softmax_att, weight)\
.view(B*self.out_channels, self.in_channels//self.groups, self.kernel_size, self.kernel_size)

if self.bias is not None:
bias = self.bias.view(self.K, -1)
aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1)
output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride,
pad_mode='pad', padding=self.padding, groups=self.groups*B, dilation=self.dilation)
else:
output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride,
pad_mode='pad', padding=self.padding, groups=self.groups*B, dilation=self.dilation)
        # restore the batch dimension; the output spatial size depends on stride/padding, so read it back
        output = output.view(B, self.out_channels, output.shape[-2], output.shape[-1])
return output


if __name__ == "__main__":
in_tensor = ms.ops.randn((2, 32, 64, 64))
    conv = DynamicConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
    out = conv(in_tensor)
    print(out.shape)  # expected: (2, 64, 64, 64)
