You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

features.py 5.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. """ features """
  2. from collections import OrderedDict
  3. from typing import Any, Dict, Iterable, List, Optional, Tuple
  4. from mindspore import nn
  5. from mindspore import Tensor
  6. def _cell_list(net: nn.Cell, flatten_sequential: bool = False) -> Iterable[Tuple[str, str, nn.Cell]]:
  7. """Yield the partially flattened cell list from the model, together with its new name and old name
  8. Args:
  9. net (nn.Cell): Network need to be partially flattened
  10. flatten_sequential (bool): Flatten the inner-layer of the sequential cell. Default: False.
  11. Returns:
  12. iterator[tuple[str, str, nn.Cell]]: The new name, the old name and corresponding cell
  13. """
  14. for name, cell in net.name_cells().items():
  15. if flatten_sequential and isinstance(cell, nn.SequentialCell):
  16. for child_name, child_cell in cell.name_cells().items():
  17. combined = [name, child_name]
  18. yield "_".join(combined), ".".join(combined), child_cell
  19. else:
  20. yield name, name, cell
  21. def _get_return_layers(feature_info: Dict[str, Any], out_indices: List[int]) -> Dict[str, int]:
  22. """Create a dict storing the "layer_name - layer_id" pair that need to be extracted"""
  23. return_layers = {}
  24. for i, x in enumerate(feature_info):
  25. if i in out_indices:
  26. return_layers[x["name"]] = i
  27. return return_layers
  28. class FeatureExtractWrapper(nn.Cell):
  29. """A wrapper of the original model, aims to perform the feature extraction at each stride.
  30. Basically, it performs 3 steps: 1. extract the return node name from the network's property
  31. `feature_info`; 2. partially flatten the network architecture if network's attribute `flatten_sequential`
  32. is True; 3. rebuild the forward steps and output the features based on the return node name.
  33. It also provide a property `out_channels` in the wrapped model, return the number of features at each output
  34. layer. This propery is usually used for the downstream tasks, which requires feature infomation at network
  35. build stage.
  36. It should be note that to apply this wrapper, there is a strong assumption that each of the outmost cell
  37. are registered in the same order as they are used. And there should be no reuse of each cell, even for the `ReLU`
  38. cell. Otherwise, the returned result may not be correct.
  39. And it should be also note that it basically rebuild the model. So the default checkpoint parameter cannot be loaded
  40. correctly once that model is wrapped. To use the pretrained weight, please load the weight first and then use this
  41. wrapper to rebuild the model.
  42. Args:
  43. net (nn.Cell): The model need to be wrapped.
  44. out_indices (list[int]): The indicies of the output features. Default: [0, 1, 2, 3, 4]
  45. """
  46. def __init__(self, net: nn.Cell, out_indices=None) -> None:
  47. super().__init__(auto_prefix=False)
  48. if out_indices is None:
  49. out_indices = [0, 1, 2, 3, 4]
  50. feature_info = self._get_feature_info(net)
  51. self.is_rewritten = getattr(net, "is_rewritten", False)
  52. flatten_sequetial = getattr(net, "flatten_sequential", False)
  53. return_layers = _get_return_layers(feature_info, out_indices)
  54. self.return_index = []
  55. if not self.is_rewritten:
  56. cells = _cell_list(net, flatten_sequential=flatten_sequetial)
  57. self.net, updated_return_layers = self._create_net(cells, return_layers)
  58. # calculate the return index
  59. for i, name in enumerate(self.net.name_cells().keys()):
  60. if name in updated_return_layers:
  61. self.return_index.append(i)
  62. else:
  63. self.net = net
  64. self.return_index = out_indices
  65. # calculate the out_channels
  66. self._out_channels = []
  67. for i in return_layers.values():
  68. self._out_channels.append(feature_info[i]["chs"])
  69. @property
  70. def out_channels(self):
  71. """The output channels of the model, filtered by the out_indices.
  72. """
  73. return self._out_channels
  74. def construct(self, x: Tensor) -> List[Tensor]:
  75. return self._collect(x)
  76. def _get_feature_info(self, net: nn.Cell) -> Dict[str, Any]:
  77. try:
  78. feature_info = getattr(net, "feature_info")
  79. except AttributeError as exc:
  80. raise AttributeError from exc
  81. return feature_info
  82. def _create_net(
  83. self, cells: Iterable[Tuple[str, str, nn.Cell]], return_layers: Dict[str, int]
  84. ) -> Tuple[nn.SequentialCell, Dict[str, int]]:
  85. layers = OrderedDict()
  86. updated_return_layers = {}
  87. remaining = set(return_layers.keys())
  88. for new_name, old_name, module in cells:
  89. layers[new_name] = module
  90. if old_name in remaining:
  91. updated_return_layers[new_name] = return_layers[old_name]
  92. remaining.remove(old_name)
  93. if not remaining:
  94. break
  95. net = nn.SequentialCell(layers)
  96. return net, updated_return_layers
  97. def _collect(self, x: Tensor) -> List[Tensor]:
  98. out = []
  99. if self.is_rewritten:
  100. xs = self.net(x)
  101. for i, s in enumerate(xs):
  102. if i in self.return_index:
  103. out.append(s)
  104. else:
  105. for i, cell in enumerate(self.net.cell_list):
  106. x = cell(x)
  107. if i in self.return_index:
  108. out.append(x)
  109. return out

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN