You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CondConv.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """ CondConv """
  2. import mindspore as ms
  3. from mindspore import nn
  4. from mindspore.common.initializer import initializer, HeNormal, HeUniform
  5. class Attention(nn.Cell):
  6. """ Attnetion """
  7. def __init__(self, in_channels, K, init_weight=True):
  8. super().__init__()
  9. self.avgpool = nn.AdaptiveAvgPool2d(1)
  10. self.net = nn.Conv2d(in_channels=in_channels, out_channels=K, kernel_size=1)
  11. self.sigmoid = nn.Sigmoid()
  12. if init_weight:
  13. self.apply(self.init_weights)
  14. def init_weights(self, cell):
  15. """ initialize weights """
  16. if isinstance(cell, nn.Conv2d):
  17. cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
  18. cell.weight.shape, cell.weight.dtype))
  19. if cell.bias is not None:
  20. cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
  21. elif isinstance(cell, nn.BatchNorm2d):
  22. cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.dtype))
  23. cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))
  24. def construct(self, x):
  25. att = self.avgpool(x)
  26. att = self.net(att).view(x.shape[0], -1)
  27. return self.sigmoid(att)
  28. class CondConv(nn.Cell):
  29. """ CondConv """
  30. def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
  31. dilation=1, groups=1, bias=True, K=4, init_weight=True):
  32. super().__init__()
  33. self.in_channels = in_channels
  34. self.out_channels = out_channels
  35. self.kernel_size = kernel_size
  36. self.stride = stride
  37. self.padding = padding
  38. self.dilation = dilation
  39. self.groups = groups
  40. self.bias = bias
  41. self.K = K
  42. self.attention = Attention(in_channels, K)
  43. self.weight = ms.Parameter(ms.ops.randn(K, out_channels, in_channels // groups, kernel_size, kernel_size),
  44. requires_grad=True)
  45. if bias:
  46. self.bias = ms.Parameter(ms.ops.randn(K, out_channels), requires_grad=True)
  47. else:
  48. self.bias = None
  49. if init_weight:
  50. self.init_weights()
  51. def init_weights(self):
  52. """ initialize weights """
  53. for i in range(self.K):
  54. self.weight[i] = initializer(HeUniform(), self.weight[i].shape, self.weight[i].dtype)
  55. def construct(self, x):
  56. B, _, H, W = x.shape
  57. softmax_att = self.attention(x)
  58. x = x.view(1, -1, H, W)
  59. weight = self.weight.view(self.K, -1)
  60. aggregate_weight = ms.ops.mm(softmax_att, weight).view(B * self.out_channels, self.in_channels // self.groups,
  61. self.kernel_size, self.kernel_size)
  62. if self.bias:
  63. bias = self.bias.view(self.K, -1)
  64. aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1)
  65. output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, pad_mode="pad",
  66. padding=self.padding, dilation=self.dilation, groups=self.groups * B)
  67. else:
  68. output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, pad_mode="pad",
  69. padding=self.padding, dilation=self.dilation, groups=self.groups * B)
  70. output = output.view(B, self.out_channels, H, W)
  71. return output
  72. if __name__ == "__main__":
  73. in_tensor = ms.ops.randn((2, 32, 64, 64))
  74. cconv = CondConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
  75. out = cconv(in_tensor)
  76. print(out.shape)

基于MindSpore的多模态股票价格预测系统研究（Informer、LSTM、RNN）