""" CondConv """ import mindspore as ms from mindspore import nn from mindspore.common.initializer import initializer, HeNormal, HeUniform class Attention(nn.Cell): """ Attnetion """ def __init__(self, in_channels, K, init_weight=True): super().__init__() self.avgpool = nn.AdaptiveAvgPool2d(1) self.net = nn.Conv2d(in_channels=in_channels, out_channels=K, kernel_size=1) self.sigmoid = nn.Sigmoid() if init_weight: self.apply(self.init_weights) def init_weights(self, cell): """ initialize weights """ if isinstance(cell, nn.Conv2d): cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype)) elif isinstance(cell, nn.BatchNorm2d): cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.dtype)) cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype)) def construct(self, x): att = self.avgpool(x) att = self.net(att).view(x.shape[0], -1) return self.sigmoid(att) class CondConv(nn.Cell): """ CondConv """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0, dilation=1, groups=1, bias=True, K=4, init_weight=True): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups self.bias = bias self.K = K self.attention = Attention(in_channels, K) self.weight = ms.Parameter(ms.ops.randn(K, out_channels, in_channels // groups, kernel_size, kernel_size), requires_grad=True) if bias: self.bias = ms.Parameter(ms.ops.randn(K, out_channels), requires_grad=True) else: self.bias = None if init_weight: self.init_weights() def init_weights(self): """ initialize weights """ for i in range(self.K): self.weight[i] = initializer(HeUniform(), self.weight[i].shape, self.weight[i].dtype) def construct(self, x): B, _, H, W = x.shape softmax_att = self.attention(x) x = x.view(1, -1, H, W) weight = self.weight.view(self.K, -1) aggregate_weight = ms.ops.mm(softmax_att, weight).view(B * self.out_channels, self.in_channels // self.groups, self.kernel_size, self.kernel_size) if self.bias: bias = self.bias.view(self.K, -1) aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1) output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, pad_mode="pad", padding=self.padding, dilation=self.dilation, groups=self.groups * B) else: output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, pad_mode="pad", padding=self.padding, dilation=self.dilation, groups=self.groups * B) output = output.view(B, self.out_channels, H, W) return output if __name__ == "__main__": in_tensor = ms.ops.randn((2, 32, 64, 64)) cconv = CondConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) out = cconv(in_tensor) print(out.shape)