|
|
|
|
|
|
""" |
|
|
|
MindSpore implementation of 'DANet' |
|
|
|
Refer to "Dual Attention Network for Scene Segmentation" |
|
|
|
""" |
|
|
|
# pylint: disable=E0401 |
|
|
|
import mindspore as ms |
|
|
|
from mindspore import nn |
|
|
|
from SelfAttention import SelfAttention |
|
|
|
from SimplifiedSelfAttention import SimplifiedScaledDotProductAttention |
|
|
|
|
|
|
|
|
|
|
|
class PositionAttentionCell(nn.Cell):
    """Position-attention branch of DANet.

    A convolution refines the feature map, then self-attention is applied
    across the flattened spatial positions (each position is one token).
    """

    def __init__(self, d_model=512, kernel_size=3):
        super().__init__()
        # 'same'-style padding; assumes an odd kernel_size — TODO confirm
        pad = (kernel_size - 1) // 2
        self.cnn = nn.Conv2d(d_model, d_model, kernel_size=kernel_size,
                             pad_mode='pad', padding=pad)
        self.pa = SelfAttention(d_model, d_k=d_model, d_v=d_model, h=1)

    def construct(self, x):
        """Apply conv + position attention; returns a (B, H*W, C) tensor."""
        batch, channels = x.shape[0], x.shape[1]
        feat = self.cnn(x)
        # (B, C, H, W) -> (B, H*W, C): spatial positions become the sequence axis
        feat = feat.view(batch, channels, -1).permute(0, 2, 1)
        return self.pa(feat)
|
|
|
|
|
|
|
|
|
|
|
class ChannelAttentionCell(nn.Cell):
    """Channel-attention branch of DANet.

    A convolution refines the feature map, then a simplified scaled
    dot-product attention runs over the channel axis (each channel,
    flattened to length H*W, is one token).
    """

    def __init__(self, d_model=512, kernel_size=3, H=7, W=7):
        super().__init__()
        # 'same'-style padding; assumes an odd kernel_size — TODO confirm
        pad = (kernel_size - 1) // 2
        self.cnn = nn.Conv2d(d_model, d_model, kernel_size=kernel_size,
                             pad_mode='pad', padding=pad)
        # Token dimension is the flattened spatial size H*W.
        self.pa = SimplifiedScaledDotProductAttention(H * W, h=1)

    def construct(self, x):
        """Apply conv + channel attention; returns a (B, C, H*W) tensor."""
        batch, channels = x.shape[0], x.shape[1]
        feat = self.cnn(x)
        # (B, C, H, W) -> (B, C, H*W): channels stay as the sequence axis
        return self.pa(feat.view(batch, channels, -1))
|
|
|
|
|
|
|
|
|
|
|
class DACell(nn.Cell):
    """DANet head: element-wise sum of the position- and channel-attention
    branches, restoring the (B, C, H, W) layout before fusion."""

    def __init__(self, d_model=512, kernel_size=3, H=7, W=7):
        super().__init__()
        self.position_attention = PositionAttentionCell(
            d_model=d_model, kernel_size=kernel_size)
        self.channel_attention = ChannelAttentionCell(
            d_model=d_model, kernel_size=kernel_size, H=H, W=W)

    def construct(self, x):
        """Fuse both attention branches; returns a tensor shaped like x."""
        batch, channels, height, width = x.shape
        # Position branch yields (B, H*W, C); undo the permute, then reshape.
        pos = self.position_attention(x)
        pos = pos.permute(0, 2, 1).view(batch, channels, height, width)
        # Channel branch yields (B, C, H*W); a reshape suffices.
        chn = self.channel_attention(x).view(batch, channels, height, width)
        return pos + chn
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: a 7x7, 512-channel feature map (batch of 50), as produced
    # by a typical ResNet backbone; output shape should match the input.
    sample = ms.ops.randn((50, 512, 7, 7))
    danet = DACell(d_model=512, kernel_size=3, H=7, W=7)
    print(danet(sample).shape)