Browse Source

Delete 'model/attention/ShuffleAttention.py'

v1
limingjuan 2 years ago
parent
commit
c6c7639631
1 changed files with 0 additions and 58 deletions
  1. +0
    -58
      model/attention/ShuffleAttention.py

+ 0
- 58
model/attention/ShuffleAttention.py View File

@@ -1,58 +0,0 @@
"""
MindSpore implementation of 'ShuffleAttention'
Refer to "SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS"
"""
import mindspore as ms
from mindspore import nn


class ShuffleAttention(nn.Cell):
""" ShuffleAttention """
def __init__(self, channels=512, G=8):
super().__init__()
self.G = G
self.channels = channels
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.gn = nn.GroupNorm(channels // (2 * G), channels // (2 * G))
self.cweight = ms.Parameter(ms.ops.zeros((1, channels // (2 * G), 1, 1)))
self.cbias = ms.Parameter(ms.ops.ones((1, channels // (2 * G), 1, 1)))
self.sweight = ms.Parameter(ms.ops.zeros((1, channels // (2 * G), 1, 1)))
self.sbias = ms.Parameter(ms.ops.ones((1, channels // (2 * G), 1, 1)))
self.sigmoid = nn.Sigmoid()

@staticmethod
def channel_shuffle(x, groups):
""" channel shuffle """
B, _, H, W = x.shape
x = x.reshape(B, groups, -1, H, W)
x = x.permute(0, 2, 1, 3, 4)
x = x.reshape(B, -1, H, W)
return x

def construct(self, x):
B, _, H, W = x.shape
x = x.view(B * self.G, -1, H, W)

x_0, x_1 = x.chunk(2, axis=1)

# channel attention
x_channel = self.avg_pool(x_0)
x_channel = self.cweight * x_channel + self.cbias
x_channel = x_0 * self.sigmoid(x_channel)

# spatial attention
x_spatial = self.gn(x_1)
x_spatial = self.sweight * x_spatial + self.sbias
x_spatial = x_1 * self.sigmoid(x_spatial)

out = ms.ops.cat((x_channel, x_spatial), axis=1)
out = out.view(B, -1, H, W)
out = self.channel_shuffle(out, 2)
return out


if __name__ == '__main__':
dummy_input = ms.ops.randn(50, 512, 7, 7)
se = ShuffleAttention(channels=512, G=8)
output = se(dummy_input)
print(output.shape)

Loading…
Cancel
Save