"""
MindSpore implementation of 'PolarizedSelfAttention'
Refer to "Polarized Self-Attention: Towards High-quality Pixel-wise Regression"
"""
import mindspore as ms
from mindspore import nn
class ParallelPolarizedSelfAttention(nn.Cell):
    """Parallel Polarized Self Attention.

    Runs a channel-only and a spatial-only attention branch on the same
    input and sums the two gated outputs.
    """

    def __init__(self, channel=512):
        super().__init__()
        half = channel // 2
        # channel-only branch
        self.ch_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))
        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ch_wz = nn.Conv2d(half, channel, kernel_size=(1, 1))
        self.ln = nn.LayerNorm((channel,))
        self.sigmoid = nn.Sigmoid()
        # spatial-only branch
        self.sp_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def _channel_branch(self, x, b, c):
        """Channel-only self-attention: scale x by a (b, c, 1, 1) gate."""
        wv = self.ch_wv(x).reshape(b, c // 2, -1)                       # bs, c//2, h*w
        wq = self.softmax_channel(self.ch_wq(x).reshape(b, -1, 1))      # bs, h*w, 1
        wz = ms.ops.matmul(wv, wq).unsqueeze(-1)                        # bs, c//2, 1, 1
        z = self.ch_wz(wz).reshape(b, c, 1).permute(0, 2, 1)            # bs, 1, c
        gate = self.sigmoid(self.ln(z)).permute(0, 2, 1).reshape(b, c, 1, 1)
        return gate * x

    def _spatial_branch(self, x, b, c, h, w):
        """Spatial-only self-attention: scale x by a (b, 1, h, w) gate."""
        wv = self.sp_wv(x).reshape(b, c // 2, -1)                       # bs, c//2, h*w
        wq = self.agp(self.sp_wq(x)).permute(0, 2, 3, 1).reshape(b, 1, c // 2)
        wq = self.softmax_spatial(wq)                                   # bs, 1, c//2
        wz = ms.ops.matmul(wq, wv)                                      # bs, 1, h*w
        gate = self.sigmoid(wz.reshape(b, 1, h, w))
        return gate * x

    def construct(self, x):
        """Apply both branches to x (b, c, h, w) and sum their outputs."""
        b, c, h, w = x.shape
        return self._spatial_branch(x, b, c, h, w) + self._channel_branch(x, b, c)
class SequentialPolarizedSelfAttention(nn.Cell):
    """Sequential Polarized Self Attention.

    Applies channel-only attention first, then spatial-only attention to
    the channel-gated result.
    """

    def __init__(self, channel=512):
        super().__init__()
        half = channel // 2
        # channel-only stage
        self.ch_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))
        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ch_wz = nn.Conv2d(half, channel, kernel_size=(1, 1))
        self.ln = nn.LayerNorm((channel,))
        self.sigmoid = nn.Sigmoid()
        # spatial-only stage
        self.sp_wv = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(channel, half, kernel_size=(1, 1))
        self.agp = nn.AdaptiveAvgPool2d((1, 1))

    def _channel_stage(self, x, b, c):
        """Channel-only self-attention: scale x by a (b, c, 1, 1) gate."""
        wv = self.ch_wv(x).reshape(b, c // 2, -1)                       # bs, c//2, h*w
        wq = self.softmax_channel(self.ch_wq(x).reshape(b, -1, 1))      # bs, h*w, 1
        wz = ms.ops.matmul(wv, wq).unsqueeze(-1)                        # bs, c//2, 1, 1
        z = self.ch_wz(wz).reshape(b, c, 1).permute(0, 2, 1)            # bs, 1, c
        gate = self.sigmoid(self.ln(z)).permute(0, 2, 1).reshape(b, c, 1, 1)
        return gate * x

    def _spatial_stage(self, x, b, c, h, w):
        """Spatial-only self-attention: scale x by a (b, 1, h, w) gate."""
        wv = self.sp_wv(x).reshape(b, c // 2, -1)                       # bs, c//2, h*w
        wq = self.agp(self.sp_wq(x)).permute(0, 2, 3, 1).reshape(b, 1, c // 2)
        wq = self.softmax_spatial(wq)                                   # bs, 1, c//2
        wz = ms.ops.matmul(wq, wv)                                      # bs, 1, h*w
        gate = self.sigmoid(wz.reshape(b, 1, h, w))
        return gate * x

    def construct(self, x):
        """Channel attention, then spatial attention on its output."""
        b, c, h, w = x.shape
        channel_out = self._channel_stage(x, b, c)
        return self._spatial_stage(channel_out, b, c, h, w)
if __name__ == '__main__':
    # Smoke test: output shape must match the (1, 512, 7, 7) input.
    x = ms.ops.randn((1, 512, 7, 7))
    attn = SequentialPolarizedSelfAttention(channel=512)
    y = attn(x)
    print(y.shape)