| @@ -1,58 +0,0 @@ | |||||
| """ | |||||
| MindSpore implementation of 'ShuffleAttention' | |||||
| Refer to "SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS" | |||||
| """ | |||||
| import mindspore as ms | |||||
| from mindspore import nn | |||||
| class ShuffleAttention(nn.Cell): | |||||
| """ ShuffleAttention """ | |||||
| def __init__(self, channels=512, G=8): | |||||
| super().__init__() | |||||
| self.G = G | |||||
| self.channels = channels | |||||
| self.avg_pool = nn.AdaptiveAvgPool2d(1) | |||||
| self.gn = nn.GroupNorm(channels // (2 * G), channels // (2 * G)) | |||||
| self.cweight = ms.Parameter(ms.ops.zeros((1, channels // (2 * G), 1, 1))) | |||||
| self.cbias = ms.Parameter(ms.ops.ones((1, channels // (2 * G), 1, 1))) | |||||
| self.sweight = ms.Parameter(ms.ops.zeros((1, channels // (2 * G), 1, 1))) | |||||
| self.sbias = ms.Parameter(ms.ops.ones((1, channels // (2 * G), 1, 1))) | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| @staticmethod | |||||
| def channel_shuffle(x, groups): | |||||
| """ channel shuffle """ | |||||
| B, _, H, W = x.shape | |||||
| x = x.reshape(B, groups, -1, H, W) | |||||
| x = x.permute(0, 2, 1, 3, 4) | |||||
| x = x.reshape(B, -1, H, W) | |||||
| return x | |||||
| def construct(self, x): | |||||
| B, _, H, W = x.shape | |||||
| x = x.view(B * self.G, -1, H, W) | |||||
| x_0, x_1 = x.chunk(2, axis=1) | |||||
| # channel attention | |||||
| x_channel = self.avg_pool(x_0) | |||||
| x_channel = self.cweight * x_channel + self.cbias | |||||
| x_channel = x_0 * self.sigmoid(x_channel) | |||||
| # spatial attention | |||||
| x_spatial = self.gn(x_1) | |||||
| x_spatial = self.sweight * x_spatial + self.sbias | |||||
| x_spatial = x_1 * self.sigmoid(x_spatial) | |||||
| out = ms.ops.cat((x_channel, x_spatial), axis=1) | |||||
| out = out.view(B, -1, H, W) | |||||
| out = self.channel_shuffle(out, 2) | |||||
| return out | |||||
| if __name__ == '__main__': | |||||
| dummy_input = ms.ops.randn(50, 512, 7, 7) | |||||
| se = ShuffleAttention(channels=512, G=8) | |||||
| output = se(dummy_input) | |||||
| print(output.shape) | |||||