|
|
|
@@ -0,0 +1,40 @@ |
|
|
|
""" |
|
|
|
MindSpore implementation of 'ParNetAttention' |
|
|
|
Refer to "Non-deep Networks" |
|
|
|
""" |
|
|
|
import mindspore as ms |
|
|
|
from mindspore import nn |
|
|
|
|
|
|
|
|
|
|
|
class ParNetAttention(nn.Cell): |
|
|
|
""" ParNetAttention """ |
|
|
|
def __init__(self, channels=512): |
|
|
|
super().__init__() |
|
|
|
self.sse = nn.SequentialCell( |
|
|
|
nn.AdaptiveAvgPool2d(1), |
|
|
|
nn.Conv2d(channels, channels, kernel_size=1), |
|
|
|
nn.Sigmoid() |
|
|
|
) |
|
|
|
self.conv1x1 = nn.SequentialCell( |
|
|
|
nn.Conv2d(channels, channels, kernel_size=1), |
|
|
|
nn.BatchNorm2d(channels) |
|
|
|
) |
|
|
|
self.conv3x3 = nn.SequentialCell( |
|
|
|
nn.Conv2d(channels, channels, kernel_size=3, pad_mode='pad', padding=1), |
|
|
|
nn.BatchNorm2d(channels) |
|
|
|
) |
|
|
|
self.silu = nn.SiLU() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x1 = self.conv1x1(x) |
|
|
|
x2 = self.conv3x3(x) |
|
|
|
x3 = self.sse(x) * x |
|
|
|
out = self.silu(x1 + x2 + x3) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
dummy_input = ms.ops.randn((50, 512, 7, 7)) |
|
|
|
pna = ParNetAttention(channels=512) |
|
|
|
output = pna(dummy_input) |
|
|
|
print(output.shape) |