|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- """ features """
- from collections import OrderedDict
- from typing import Any, Dict, Iterable, List, Tuple
-
- from mindspore import nn
- from mindspore import Tensor
-
-
- def _cell_list(net: nn.Cell, flatten_sequential: bool = False) -> Iterable[Tuple[str, str, nn.Cell]]:
- """Yield the partially flattened cell list from the model, together with its new name and old name
-
- Args:
- net (nn.Cell): Network need to be partially flattened
- flatten_sequential (bool): Flatten the inner-layer of the sequential cell. Default: False.
-
- Returns:
- iterator[tuple[str, str, nn.Cell]]: The new name, the old name and corresponding cell
- """
- for name, cell in net.name_cells().items():
- if flatten_sequential and isinstance(cell, nn.SequentialCell):
- for child_name, child_cell in cell.name_cells().items():
- combined = [name, child_name]
- yield "_".join(combined), ".".join(combined), child_cell
- else:
- yield name, name, cell
-
-
- def _get_return_layers(feature_info: Dict[str, Any], out_indices: List[int]) -> Dict[str, int]:
- """Create a dict storing the "layer_name - layer_id" pair that need to be extracted"""
- return_layers = {}
- for i, x in enumerate(feature_info):
- if i in out_indices:
- return_layers[x["name"]] = i
- return return_layers
-
-
- class FeatureExtractWrapper(nn.Cell):
- """A wrapper of the original model, aims to perform the feature extraction at each stride.
- Basically, it performs 3 steps: 1. extract the return node name from the network's property
- `feature_info`; 2. partially flatten the network architecture if network's attribute `flatten_sequential`
- is True; 3. rebuild the forward steps and output the features based on the return node name.
-
- It also provide a property `out_channels` in the wrapped model, return the number of features at each output
- layer. This propery is usually used for the downstream tasks, which requires feature infomation at network
- build stage.
-
- It should be note that to apply this wrapper, there is a strong assumption that each of the outmost cell
- are registered in the same order as they are used. And there should be no reuse of each cell, even for the `ReLU`
- cell. Otherwise, the returned result may not be correct.
-
- And it should be also note that it basically rebuild the model. So the default checkpoint parameter cannot be loaded
- correctly once that model is wrapped. To use the pretrained weight, please load the weight first and then use this
- wrapper to rebuild the model.
-
- Args:
- net (nn.Cell): The model need to be wrapped.
- out_indices (list[int]): The indicies of the output features. Default: [0, 1, 2, 3, 4]
- """
-
- def __init__(self, net: nn.Cell, out_indices=None) -> None:
- super().__init__(auto_prefix=False)
-
- if out_indices is None:
- out_indices = [0, 1, 2, 3, 4]
- feature_info = self._get_feature_info(net)
- self.is_rewritten = getattr(net, "is_rewritten", False)
- flatten_sequetial = getattr(net, "flatten_sequential", False)
- return_layers = _get_return_layers(feature_info, out_indices)
- self.return_index = []
-
- if not self.is_rewritten:
- cells = _cell_list(net, flatten_sequential=flatten_sequetial)
- self.net, updated_return_layers = self._create_net(cells, return_layers)
-
- # calculate the return index
- for i, name in enumerate(self.net.name_cells().keys()):
- if name in updated_return_layers:
- self.return_index.append(i)
- else:
- self.net = net
- self.return_index = out_indices
-
- # calculate the out_channels
- self._out_channels = []
- for i in return_layers.values():
- self._out_channels.append(feature_info[i]["chs"])
-
- @property
- def out_channels(self):
- """The output channels of the model, filtered by the out_indices.
- """
- return self._out_channels
-
- def construct(self, x: Tensor) -> List[Tensor]:
- return self._collect(x)
-
- def _get_feature_info(self, net: nn.Cell) -> Dict[str, Any]:
- try:
- feature_info = getattr(net, "feature_info")
- except AttributeError as exc:
- raise AttributeError from exc
- return feature_info
-
- def _create_net(
- self, cells: Iterable[Tuple[str, str, nn.Cell]], return_layers: Dict[str, int]
- ) -> Tuple[nn.SequentialCell, Dict[str, int]]:
- layers = OrderedDict()
- updated_return_layers = {}
- remaining = set(return_layers.keys())
- for new_name, old_name, module in cells:
- layers[new_name] = module
- if old_name in remaining:
- updated_return_layers[new_name] = return_layers[old_name]
- remaining.remove(old_name)
- if not remaining:
- break
-
- net = nn.SequentialCell(layers)
- return net, updated_return_layers
-
- def _collect(self, x: Tensor) -> List[Tensor]:
- out = []
-
- if self.is_rewritten:
- xs = self.net(x)
-
- for i, s in enumerate(xs):
- if i in self.return_index:
- out.append(s)
- else:
- for i, cell in enumerate(self.net.cell_list):
- x = cell(x)
- if i in self.return_index:
- out.append(x)
-
- return out
|