Browse Source

Delete 'model/features.py'

v1
limingjuan 2 years ago
parent
commit
96e5a7825e
1 changed files with 0 additions and 136 deletions
  1. +0
    -136
      model/features.py

+ 0
- 136
model/features.py View File

@@ -1,136 +0,0 @@
""" features """
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Tuple

from mindspore import nn
from mindspore import Tensor


def _cell_list(net: nn.Cell, flatten_sequential: bool = False) -> Iterable[Tuple[str, str, nn.Cell]]:
"""Yield the partially flattened cell list from the model, together with its new name and old name

Args:
net (nn.Cell): Network need to be partially flattened
flatten_sequential (bool): Flatten the inner-layer of the sequential cell. Default: False.

Returns:
iterator[tuple[str, str, nn.Cell]]: The new name, the old name and corresponding cell
"""
for name, cell in net.name_cells().items():
if flatten_sequential and isinstance(cell, nn.SequentialCell):
for child_name, child_cell in cell.name_cells().items():
combined = [name, child_name]
yield "_".join(combined), ".".join(combined), child_cell
else:
yield name, name, cell


def _get_return_layers(feature_info: Dict[str, Any], out_indices: List[int]) -> Dict[str, int]:
"""Create a dict storing the "layer_name - layer_id" pair that need to be extracted"""
return_layers = {}
for i, x in enumerate(feature_info):
if i in out_indices:
return_layers[x["name"]] = i
return return_layers


class FeatureExtractWrapper(nn.Cell):
"""A wrapper of the original model, aims to perform the feature extraction at each stride.
Basically, it performs 3 steps: 1. extract the return node name from the network's property
`feature_info`; 2. partially flatten the network architecture if network's attribute `flatten_sequential`
is True; 3. rebuild the forward steps and output the features based on the return node name.

It also provide a property `out_channels` in the wrapped model, return the number of features at each output
layer. This propery is usually used for the downstream tasks, which requires feature infomation at network
build stage.

It should be note that to apply this wrapper, there is a strong assumption that each of the outmost cell
are registered in the same order as they are used. And there should be no reuse of each cell, even for the `ReLU`
cell. Otherwise, the returned result may not be correct.

And it should be also note that it basically rebuild the model. So the default checkpoint parameter cannot be loaded
correctly once that model is wrapped. To use the pretrained weight, please load the weight first and then use this
wrapper to rebuild the model.

Args:
net (nn.Cell): The model need to be wrapped.
out_indices (list[int]): The indicies of the output features. Default: [0, 1, 2, 3, 4]
"""

def __init__(self, net: nn.Cell, out_indices=None) -> None:
super().__init__(auto_prefix=False)

if out_indices is None:
out_indices = [0, 1, 2, 3, 4]
feature_info = self._get_feature_info(net)
self.is_rewritten = getattr(net, "is_rewritten", False)
flatten_sequetial = getattr(net, "flatten_sequential", False)
return_layers = _get_return_layers(feature_info, out_indices)
self.return_index = []

if not self.is_rewritten:
cells = _cell_list(net, flatten_sequential=flatten_sequetial)
self.net, updated_return_layers = self._create_net(cells, return_layers)

# calculate the return index
for i, name in enumerate(self.net.name_cells().keys()):
if name in updated_return_layers:
self.return_index.append(i)
else:
self.net = net
self.return_index = out_indices

# calculate the out_channels
self._out_channels = []
for i in return_layers.values():
self._out_channels.append(feature_info[i]["chs"])

@property
def out_channels(self):
"""The output channels of the model, filtered by the out_indices.
"""
return self._out_channels

def construct(self, x: Tensor) -> List[Tensor]:
return self._collect(x)

def _get_feature_info(self, net: nn.Cell) -> Dict[str, Any]:
try:
feature_info = getattr(net, "feature_info")
except AttributeError as exc:
raise AttributeError from exc
return feature_info

def _create_net(
self, cells: Iterable[Tuple[str, str, nn.Cell]], return_layers: Dict[str, int]
) -> Tuple[nn.SequentialCell, Dict[str, int]]:
layers = OrderedDict()
updated_return_layers = {}
remaining = set(return_layers.keys())
for new_name, old_name, module in cells:
layers[new_name] = module
if old_name in remaining:
updated_return_layers[new_name] = return_layers[old_name]
remaining.remove(old_name)
if not remaining:
break

net = nn.SequentialCell(layers)
return net, updated_return_layers

def _collect(self, x: Tensor) -> List[Tensor]:
out = []

if self.is_rewritten:
xs = self.net(x)

for i, s in enumerate(xs):
if i in self.return_index:
out.append(s)
else:
for i, cell in enumerate(self.net.cell_list):
x = cell(x)
if i in self.return_index:
out.append(x)

return out

Loading…
Cancel
Save