| """ DynamicConv """ | |||||
| import mindspore as ms | |||||
| from mindspore import nn | |||||
| from mindspore.common.initializer import initializer, HeNormal, HeUniform | |||||
class Attention(nn.Cell):
    """Kernel attention: squeeze (global average pool) -> reduce -> expand to K
    logits, turned into mixing weights by a temperature-scaled softmax.
    """
    def __init__(self, in_channels, ratio, K, temperature=30, init_weight=True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.temperature = temperature
        assert in_channels > ratio, "in_channels must be larger than ratio"
        hidden_channels = in_channels // ratio
        self.net = nn.SequentialCell(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, has_bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, K, kernel_size=1, has_bias=False)
        )
        # The attention over the K experts is a softmax scaled by temperature
        # (the stored temperature was previously unused: the module returned a
        # plain sigmoid, yet the caller treats the result as softmax weights).
        self.softmax = nn.Softmax(axis=-1)
        if init_weight:
            self.apply(self.init_weight)
    def update_temperature(self):
        """Anneal the softmax temperature towards 1 (called e.g. once per epoch).
        """
        if self.temperature > 1:
            self.temperature -= 1
    def init_weight(self, cell):
        """init Conv2d and BatchNorm2d
        """
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
                                             cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))
    def construct(self, x):
        # Pool to (B, C, 1, 1), project to K logits, then convert to mixing
        # weights; a higher temperature gives softer (more uniform) attention.
        att = self.avgpool(x)
        att = self.net(att).view(x.shape[0], -1)
        return self.softmax(att / self.temperature)


class DynamicConv(nn.Cell):
    """DynamicConv: a conv layer whose kernel is assembled per sample as an
    attention-weighted mixture of K expert kernels.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
                 dilation=1, groups=1, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = Attention(in_channels, ratio=ratio, K=K,
                                   temperature=temperature, init_weight=init_weight)
        # K expert kernels, mixed per sample by the attention weights.
        self.weight = ms.Parameter(
            ms.ops.randn((K, out_channels, in_channels // groups, kernel_size, kernel_size)),
            requires_grad=True)
        if bias:
            self.bias = ms.Parameter(ms.ops.randn((K, out_channels)), requires_grad=True)
        else:
            self.bias = None
        if init_weight:
            self.init_weight()
    def init_weight(self):
        """Re-initialize each of the K expert kernels with HeUniform.
        """
        # Build the initialized tensor expert by expert, then commit it once via
        # set_data: item assignment on a Parameter is not a reliable way to
        # update its stored data.
        per_expert_shape = self.weight.shape[1:]
        weight_init = ms.ops.stack(
            [initializer(HeUniform(), per_expert_shape, self.weight.dtype).init_data()
             for _ in range(self.K)])
        self.weight.set_data(weight_init)
    def construct(self, x):
        B, _, H, W = x.shape
        # Per-sample mixing weights over the K experts: (B, K).
        softmax_att = self.attention(x)
        # Fold the batch into the channel dim so a single grouped conv applies
        # a different aggregated kernel to each sample.
        x = x.view(1, -1, H, W)
        weight = self.weight.view(self.K, -1)
        aggregate_weight = ms.ops.mm(softmax_att, weight)\
            .view(B * self.out_channels, self.in_channels // self.groups,
                  self.kernel_size, self.kernel_size)
        aggregate_bias = None
        if self.bias is not None:
            aggregate_bias = ms.ops.mm(softmax_att, self.bias.view(self.K, -1)).view(-1)
        output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias,
                               stride=self.stride, pad_mode='pad', padding=self.padding,
                               groups=self.groups * B, dilation=self.dilation)
        # Unfold the batch; take the spatial size from the conv output rather
        # than reusing (H, W), so strides and paddings that change the spatial
        # size still work.
        output = output.view(B, self.out_channels, output.shape[-2], output.shape[-1])
        return output


if __name__ == "__main__":
    # Smoke test: with stride=1 and padding=1 on a 3x3 kernel, the spatial
    # size is preserved and only the channel count changes.
    in_tensor = ms.ops.randn((2, 32, 64, 64))
    cond = DynamicConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
    out = cond(in_tensor)
    print(out.shape)  # (2, 64, 64, 64)
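
    # A minimal sketch of driving the temperature schedule from a training
    # loop (an assumption for illustration: the class itself does not
    # prescribe when update_temperature should be called).
    for _ in range(3):
        cond.attention.update_temperature()
    print(cond.attention.temperature)  # 30 -> 27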