You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BAM.py 3.6 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. MindSpore implementation of 'BAM'
  3. Refer to "BAM: Bottleneck Attention Module"
  4. """
  5. import mindspore as ms
  6. from mindspore import nn
  7. from mindspore.common.initializer import initializer, HeNormal, Normal
  class Flatten(nn.Cell):
      """Collapse every axis after the batch axis into a single one.

      A tensor of shape (N, d1, d2, ...) is returned as (N, d1 * d2 * ...).
      """

      def construct(self, x):
          batch_size = x.shape[0]
          return x.view(batch_size, -1)
  class ChannelGate(nn.Cell):
      """Channel-attention branch of BAM.

      Averages the input over its two spatial axes, feeds the per-channel
      descriptor through a bottleneck MLP (Dense -> BatchNorm1d -> ReLU,
      repeated ``num_layers`` times, then a final Dense back to
      ``gate_channel``), and broadcasts the resulting per-channel logits back
      to the input's shape.

      Args:
          gate_channel: Number of input channels C (axis 1 of the input).
          reduction_ratio: Hidden width is ``gate_channel // reduction_ratio``.
          num_layers: Number of hidden Dense/BatchNorm1d/ReLU stages.
      """

      def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
          super().__init__()
          hidden = gate_channel // reduction_ratio
          # Widths along the MLP: C -> hidden (num_layers times) -> C.
          gate_channels = [gate_channel] + [hidden] * num_layers + [gate_channel]
          self.gate_c = nn.SequentialCell()
          self.gate_c.append(Flatten())
          # Hidden stages. Sub-cells are appended in the same order as before
          # so that parameter names (and saved checkpoints) stay compatible.
          for in_ch, out_ch in zip(gate_channels[:-2], gate_channels[1:-1]):
              self.gate_c.append(nn.Dense(in_ch, out_ch))
              self.gate_c.append(nn.BatchNorm1d(out_ch))
              self.gate_c.append(nn.ReLU())
          self.gate_c.append(nn.Dense(gate_channels[-2], gate_channels[-1]))

      def construct(self, in_tensor):
          # Global average pool over H and W -> (N, C, 1, 1). A square
          # avg_pool2d kernel of size H only reduces to 1x1 when H == W;
          # averaging over both spatial axes handles any H x W.
          avg_pool = in_tensor.mean(axis=(2, 3), keep_dims=True)
          # MLP output is (N, C); restore the two spatial axes and broadcast.
          out = self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)
          return out
  class SpatialGate(nn.Cell):
      """Spatial-attention branch of BAM.

      A 1x1 conv squeezes C channels to ``C // reduction_ratio``, a stack of
      ``dilation_conv_num`` dilated 3x3 convs aggregates context, and a final
      1x1 conv produces a single-channel map that is broadcast back over all
      C channels. Every conv except the last is followed by BatchNorm2d and
      ReLU.

      Args:
          gate_channel: Number of input channels C.
          reduction_ratio: Channel reduction factor for the intermediate convs.
          dilation_conv_num: Number of dilated 3x3 conv stages.
          dilation_val: Dilation of the 3x3 convs. Padding equals the
              dilation, which keeps H and W unchanged for a 3x3 kernel.
      """

      def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
          super().__init__()
          mid = gate_channel // reduction_ratio
          self.gate_s = nn.SequentialCell()
          stack = self.gate_s

          # 1x1 channel reduction.
          stack.append(nn.Conv2d(gate_channel, mid, kernel_size=1))
          stack.append(nn.BatchNorm2d(mid))
          stack.append(nn.ReLU())

          # Dilated 3x3 context stages.
          for _ in range(dilation_conv_num):
              stack.append(nn.Conv2d(mid, mid,
                                     kernel_size=3, pad_mode='pad',
                                     padding=dilation_val, dilation=dilation_val))
              stack.append(nn.BatchNorm2d(mid))
              stack.append(nn.ReLU())

          # Project down to a single-channel spatial map.
          stack.append(nn.Conv2d(mid, 1, kernel_size=1))

      def construct(self, in_tensor):
          spatial_map = self.gate_s(in_tensor)
          return spatial_map.expand_as(in_tensor)
  class BAMBlock(nn.Cell):
      """Bottleneck Attention Module.

      Sums the spatial and channel attention logits (both already broadcast
      to the input's shape), applies a sigmoid to obtain an attention map M,
      and returns ``(1 + M) * x``, i.e. ``x + x * M``.

      Args:
          gate_channel: Number of input channels C.
          reduction_ratio: Bottleneck reduction factor used by both branches.
          num_layers: Hidden Dense stages in the channel branch.
          dilation_conv_num: Dilated 3x3 conv stages in the spatial branch.
          dilation_val: Dilation of the spatial branch's 3x3 convs.
      """

      def __init__(self, gate_channel, reduction_ratio=16, num_layers=1,
                   dilation_conv_num=2, dilation_val=4):
          super().__init__()
          self.channel_att = ChannelGate(gate_channel,
                                         reduction_ratio=reduction_ratio,
                                         num_layers=num_layers)
          self.spatial_att = SpatialGate(gate_channel,
                                         reduction_ratio=reduction_ratio,
                                         dilation_conv_num=dilation_conv_num,
                                         dilation_val=dilation_val)
          self.sigmoid = nn.Sigmoid()
          self.apply(self._init_weights)

      def _init_weights(self, cell):
          """Re-initialise Conv2d/Dense weights and zero their biases if present.

          Conv2d weights use He-normal with ``mode='fan_out'``; Dense weights
          use a normal distribution with sigma 0.001. Other cells are left
          untouched.
          """
          if isinstance(cell, nn.Conv2d):
              weight_init = HeNormal(mode='fan_out')
          elif isinstance(cell, nn.Dense):
              weight_init = Normal(sigma=0.001)
          else:
              return
          cell.weight.set_data(initializer(weight_init, cell.weight.shape, cell.weight.dtype))
          if cell.bias is not None:
              cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))

      def construct(self, x):
          sa_out = self.spatial_att(x)
          ca_out = self.channel_att(x)
          weight = self.sigmoid(sa_out + ca_out)
          # Residual attention: keep the input and add its attended version.
          out = (1 + weight) * x
          return out
  if __name__ == "__main__":
      # Smoke test: push a random (12, 128, 14, 14) batch through a
      # 128-channel BAM block; the output shape equals the input's.
      sample = ms.ops.randn(12, 128, 14, 14)
      block = BAMBlock(128)
      result = block(sample)
      print(result.shape)

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN