From 6c75fbf6d89671573f357c235aa4238d508850f9 Mon Sep 17 00:00:00 2001
From: Huyf9 <568558927@qq.com>
Date: Wed, 26 Jul 2023 22:23:56 +0800
Subject: [PATCH] add DynamicConv

---
 model/conv/DynamicConv.py | 118 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 118 insertions(+)
 create mode 100644 model/conv/DynamicConv.py

diff --git a/model/conv/DynamicConv.py b/model/conv/DynamicConv.py
new file mode 100644
index 0000000..ecb84a3
--- /dev/null
+++ b/model/conv/DynamicConv.py
@@ -0,0 +1,118 @@
+""" DynamicConv """
+import mindspore as ms
+from mindspore import nn
+from mindspore.common.initializer import initializer, HeNormal, HeUniform
+
+
+class Attention(nn.Cell):
+    """Kernel attention with a temperature-scaled softmax
+    """
+    def __init__(self, in_channels, ratio, K, temperature=30, init_weight=True):
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.temperature = temperature
+        assert in_channels > ratio
+        hidden_channels = in_channels // ratio
+        self.net = nn.SequentialCell(
+            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, has_bias=False),
+            nn.ReLU(),
+            nn.Conv2d(hidden_channels, K, kernel_size=1, has_bias=False)
+        )
+        if init_weight:
+            self.apply(self.init_weight)
+
+    def update_temperature(self):
+        """anneal the softmax temperature towards 1
+        """
+        if self.temperature > 1:
+            self.temperature -= 1
+
+    def init_weight(self, cell):
+        """init Conv2d and BatchNorm2d
+        """
+        if isinstance(cell, nn.Conv2d):
+            cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
+                                             cell.weight.shape, cell.weight.dtype))
+            if cell.bias is not None:
+                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
+        elif isinstance(cell, nn.BatchNorm2d):
+            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
+            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))
+
+    def construct(self, x):
+        att = self.avgpool(x)
+        att = self.net(att).view(x.shape[0], -1)
+        # Temperature-scaled softmax so the K kernel weights form a convex
+        # combination; the previous sigmoid left self.temperature unused.
+        return ms.ops.softmax(att / self.temperature, axis=1)
+
+
+class DynamicConv(nn.Cell):
+    """DynamicConv: convolution with an attention-weighted aggregation
+    of K expert kernels
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
+                 dilation=1, groups=1, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.K = K
+        self.attention = Attention(in_channels, ratio=ratio, K=K, temperature=temperature,
+                                   init_weight=init_weight)
+
+        self.weight = ms.Parameter(
+            ms.ops.randn((K, out_channels, in_channels // groups, kernel_size, kernel_size)),
+            requires_grad=True)
+        if bias:
+            self.bias = ms.Parameter(ms.ops.randn((K, out_channels)), requires_grad=True)
+        else:
+            self.bias = None
+
+        if init_weight:
+            self.init_weight()
+
+    def init_weight(self):
+        """init each of the K kernels with HeUniform
+        """
+        for i in range(self.K):
+            self.weight[i] = initializer(HeUniform(), self.weight[i].shape,
+                                         self.weight[i].dtype).init_data()
+
+    def construct(self, x):
+        B, _, H, W = x.shape
+        softmax_att = self.attention(x)  # (B, K)
+        # Fold the batch into the channel axis so a single grouped convolution
+        # applies a different aggregated kernel to every sample.
+        x = x.view(1, -1, H, W)
+        weight = self.weight.view(self.K, -1)
+        aggregate_weight = ms.ops.mm(softmax_att, weight)\
+            .view(B * self.out_channels, self.in_channels // self.groups,
+                  self.kernel_size, self.kernel_size)
+
+        if self.bias is not None:
+            bias = self.bias.view(self.K, -1)
+            aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1)
+            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias,
+                                   stride=self.stride, pad_mode='pad', padding=self.padding,
+                                   groups=self.groups * B, dilation=self.dilation)
+        else:
+            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride,
+                                   pad_mode='pad', padding=self.padding,
+                                   groups=self.groups * B, dilation=self.dilation)
+        # Read the spatial size back from the conv output instead of assuming
+        # H, W, so stride > 1 and arbitrary padding reshape correctly.
+        output = output.view(B, self.out_channels, output.shape[-2], output.shape[-1])
+        return output
+
+
+if __name__ == "__main__":
+    in_tensor = ms.ops.randn((2, 32, 64, 64))
+    conv = DynamicConv(in_channels=32, out_channels=64, kernel_size=3, stride=1,
+                       padding=1, bias=False)
+    out = conv(in_tensor)
+    print(out.shape)
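
Note (not part of the patch): Attention.update_temperature() is defined above but nothing in the file calls it. Below is a minimal sketch of how the annealing hook might be driven from a training loop; the epoch count and the elided training step are illustrative assumptions, and only DynamicConv and update_temperature() come from this patch.

# Sketch: anneal the kernel-attention temperature once per epoch (30 -> 1),
# sharpening the softmax over the K kernels as training progresses.
# The per-epoch training step is elided; num_epochs is an assumed value.
from mindspore import nn
from model.conv.DynamicConv import DynamicConv

net = nn.SequentialCell(
    DynamicConv(in_channels=3, out_channels=16, kernel_size=3, stride=1,
                padding=1, bias=False),
    nn.ReLU(),
)

num_epochs = 30
for epoch in range(num_epochs):
    # ... run one epoch of forward/backward/optimizer updates here ...
    # Then step every DynamicConv's temperature down by 1 for the next epoch.
    for _, cell in net.cells_and_names():
        if isinstance(cell, DynamicConv):
            cell.attention.update_temperature()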