From c6c7639631f8c5a7eac0be6de98a51f645bf0899 Mon Sep 17 00:00:00 2001 From: limingjuan <3151508285@qq.com> Date: Mon, 18 Sep 2023 12:15:12 +0800 Subject: [PATCH] Delete 'model/attention/ShuffleAttention.py' --- model/attention/ShuffleAttention.py | 58 ----------------------------- 1 file changed, 58 deletions(-) delete mode 100644 model/attention/ShuffleAttention.py diff --git a/model/attention/ShuffleAttention.py b/model/attention/ShuffleAttention.py deleted file mode 100644 index fe6da58..0000000 --- a/model/attention/ShuffleAttention.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -MindSpore implementation of 'ShuffleAttention' -Refer to "SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS" -""" -import mindspore as ms -from mindspore import nn - - -class ShuffleAttention(nn.Cell): - """ ShuffleAttention """ - def __init__(self, channels=512, G=8): - super().__init__() - self.G = G - self.channels = channels - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.gn = nn.GroupNorm(channels // (2 * G), channels // (2 * G)) - self.cweight = ms.Parameter(ms.ops.zeros((1, channels // (2 * G), 1, 1))) - self.cbias = ms.Parameter(ms.ops.ones((1, channels // (2 * G), 1, 1))) - self.sweight = ms.Parameter(ms.ops.zeros((1, channels // (2 * G), 1, 1))) - self.sbias = ms.Parameter(ms.ops.ones((1, channels // (2 * G), 1, 1))) - self.sigmoid = nn.Sigmoid() - - @staticmethod - def channel_shuffle(x, groups): - """ channel shuffle """ - B, _, H, W = x.shape - x = x.reshape(B, groups, -1, H, W) - x = x.permute(0, 2, 1, 3, 4) - x = x.reshape(B, -1, H, W) - return x - - def construct(self, x): - B, _, H, W = x.shape - x = x.view(B * self.G, -1, H, W) - - x_0, x_1 = x.chunk(2, axis=1) - - # channel attention - x_channel = self.avg_pool(x_0) - x_channel = self.cweight * x_channel + self.cbias - x_channel = x_0 * self.sigmoid(x_channel) - - # spatial attention - x_spatial = self.gn(x_1) - x_spatial = self.sweight * x_spatial + self.sbias - x_spatial = x_1 * self.sigmoid(x_spatial) - - out = ms.ops.cat((x_channel, x_spatial), axis=1) - out = out.view(B, -1, H, W) - out = self.channel_shuffle(out, 2) - return out - - -if __name__ == '__main__': - dummy_input = ms.ops.randn(50, 512, 7, 7) - se = ShuffleAttention(channels=512, G=8) - output = se(dummy_input) - print(output.shape)