ParNetAttention

pull/3/head
Huyf9 committed c8be95322f 2 years ago
1 changed file: model/attention/ParNetAttention.py

model/attention/ParNetAttention.py

"""
MindSpore implementation of 'ParNetAttention'
Refer to "Non-deep Networks" (arXiv:2110.07641)
"""
import mindspore as ms
from mindspore import nn


class ParNetAttention(nn.Cell):
    """ ParNetAttention """
    def __init__(self, channels=512):
        super().__init__()
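        # SSE branch (squeeze-and-excitation style): global average pooling,
        # then a 1x1 convolution and a sigmoid gate that re-weights channels.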
        self.sse = nn.SequentialCell(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.Sigmoid()
        )
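        # 1x1 convolution branch.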
        self.conv1x1 = nn.SequentialCell(
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.BatchNorm2d(channels)
        )
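        # 3x3 convolution branch.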
        self.conv3x3 = nn.SequentialCell(
            nn.Conv2d(channels, channels, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm2d(channels)
        )
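        # SiLU activation that fuses the three parallel branches.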
        self.silu = nn.SiLU()

    def construct(self, x):
        x1 = self.conv1x1(x)   # 1x1 convolution branch
        x2 = self.conv3x3(x)   # 3x3 convolution branch
        x3 = self.sse(x) * x   # channel attention re-weights the input
        out = self.silu(x1 + x2 + x3)
        return out


if __name__ == "__main__":
    dummy_input = ms.ops.randn((50, 512, 7, 7))
    pna = ParNetAttention(channels=512)
    output = pna(dummy_input)
    print(output.shape)  # expected: (50, 512, 7, 7)
