|
- """ involution """
- import mindspore as ms
- from mindspore import nn
-
-
class Involution(nn.Cell):
    """Involution layer.

    For every output position, a ``kernel_size x kernel_size`` kernel is
    generated from the input itself. The input is optionally average-pooled,
    then passed through a 1x1 bottleneck: conv -> BatchNorm -> ReLU -> conv.
    The bottleneck produces ``group`` kernels per position. Each kernel is
    shared by the ``in_channels // group`` channels of its group and applied
    to the unfolded input neighbourhood of that position.

    Args:
        kernel_size: Side length of the generated square kernel. Must be a
            positive odd integer. The input is unfolded with padding
            ``kernel_size // 2``, which gives the same number of positions as
            the pooled kernel grid only for odd sizes.
        in_channels: Number of input channels. The output has the same
            number of channels.
        stride: Spatial stride (>= 1). When greater than 1, the
            kernel-generation branch average-pools the input with a
            ``stride x stride`` window.
        group: Number of channel groups. Must divide ``in_channels``.
        ratio: Channel reduction factor of the bottleneck. Must be in
            ``[1, in_channels]`` so the bottleneck keeps at least one channel.

    Raises:
        ValueError: If any of the constraints above is violated.
    """

    def __init__(self, kernel_size, in_channels=4, stride=1, group=1, ratio=4):
        super().__init__()
        # Explicit exceptions instead of ``assert``: asserts vanish under ``python -O``.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        if group < 1 or in_channels % group != 0:
            raise ValueError(
                f"group must be a positive divisor of in_channels={in_channels}, got {group}")
        if ratio < 1 or ratio > in_channels:
            raise ValueError(f"ratio must be in [1, in_channels={in_channels}], got {ratio}")

        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.stride = stride
        self.group = group
        # Number of channels that share one generated kernel.
        self.group_channel = in_channels // group

        reduced_channels = in_channels // ratio
        # Kernel-generation branch: a 1x1 bottleneck that emits one k*k kernel
        # per group at every (pooled) spatial position.
        self.conv1 = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(reduced_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(reduced_channels, group * kernel_size * kernel_size, kernel_size=1)
        # Pool the kernel-generation input so its grid matches the strided unfold grid.
        self.avgpool = nn.AvgPool2d(stride, stride) if stride > 1 else nn.Identity()

    def construct(self, x):
        """Apply involution to ``x``.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == in_channels``.
                ``H`` and ``W`` are assumed to be divisible by ``stride``.
                NOTE(review): otherwise the pooled kernel grid (floor) and the
                unfolded grid (ceil) differ and the reshape below fails.

        Returns:
            Tensor of shape ``(B, C, h, w)``, where ``(h, w)`` is the spatial
            size of the generated kernel grid (``H // stride, W // stride``
            for divisible inputs).
        """
        batch, channels, _, _ = x.shape
        k2 = self.kernel_size * self.kernel_size

        # Per-position kernels: (B, group, 1, k*k, h, w). The singleton axis
        # broadcasts one kernel over all channels of its group.
        weight = self.conv2(self.relu(self.bn(self.conv1(self.avgpool(x)))))
        _, _, out_h, out_w = weight.shape
        weight = weight.reshape(batch, self.group, k2, out_h, out_w).unsqueeze(2)

        # Input neighbourhoods: per the MindSpore docs, unfold returns
        # (B, C * k*k, L). Regroup that to (B, group, C // group, k*k, h, w),
        # reusing the kernel grid size so both operands line up.
        x_unfold = ms.ops.unfold(x, kernel_size=self.kernel_size,
                                 padding=self.kernel_size // 2, stride=self.stride)
        x_unfold = x_unfold.reshape(batch, self.group, channels // self.group, k2, out_h, out_w)

        # Weighted sum over the k*k neighbourhood axis, then merge the groups back into C.
        out = (x_unfold * weight).sum(axis=3)
        return out.reshape(batch, channels, out_h, out_w)
-
-
if __name__ == "__main__":
    # Smoke test: run a stride-1, 3x3 involution on a random 4-channel 64x64
    # batch and print the resulting shape.
    sample = ms.ops.randn(1, 4, 64, 64)
    layer = Involution(3, in_channels=4, stride=1)
    result = layer(sample)
    print(result.shape)
|