You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Involution.py 1.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. """ involution """
  2. import mindspore as ms
  3. from mindspore import nn
  4. class Involution(nn.Cell):
  5. """ Involution
  6. """
  7. def __init__(self, kernel_size, in_channels=4, stride=1, group=1, ratio=4):
  8. super().__init__()
  9. self.kernel_size = kernel_size
  10. self.in_channels = in_channels
  11. self.stride = stride
  12. self.group = group
  13. assert self.in_channels % group == 0
  14. self.group_channel = self.in_channels // group
  15. self.conv1 = nn.Conv2d(self.in_channels, self.in_channels//ratio, kernel_size=1)
  16. self.bn = nn.BatchNorm2d(in_channels // ratio)
  17. self.relu = nn.ReLU()
  18. self.conv2 = nn.Conv2d(self.in_channels // ratio,
  19. self.group * self.kernel_size * self.kernel_size, kernel_size=1)
  20. self.avgpool = nn.AvgPool2d(stride, stride) if stride > 1 else nn.Identity()
  21. def construct(self, x):
  22. B, C, H, W = x.shape
  23. weight = self.conv2(self.relu(self.bn(self.conv1(self.avgpool(x)))))
  24. b, _, h, w = weight.shape
  25. weight = weight.reshape(b, self.group, self.kernel_size*self.kernel_size, h, w).unsqueeze(2)
  26. x_unfold = ms.ops.unfold(x, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=self.stride)
  27. x_unfold = x_unfold.reshape(B, self.group, C // self.group,
  28. self.kernel_size * self.kernel_size, H // self.stride, W // self.stride)
  29. out = (x_unfold * weight).sum(axis=3)
  30. out = out.reshape(B, C, H // self.stride, W // self.stride)
  31. return out
  32. if __name__ == "__main__":
  33. in_tensor = ms.ops.randn(1, 4, 64, 64)
  34. involution = Involution(3, in_channels=4, stride=1)
  35. output = involution(in_tensor)
  36. print(output.shape)

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN