| @@ -84,10 +84,11 @@ def graph_based_converter(graph_path: str, sample_shape: tuple, | |||||
| """ | """ | ||||
| from .third_party_graph import GraphFactory | from .third_party_graph import GraphFactory | ||||
| from .hierarchical_tree import HierarchicalTreeFactory | |||||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | ||||
| checkpoint=checkpoint_path) | checkpoint=checkpoint_path) | ||||
| hierarchical_tree = graph_obj.to_hierarchical_tree() | |||||
| hierarchical_tree = HierarchicalTreeFactory.create(graph_obj) | |||||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | ||||
| report_folder=report_folder) | report_folder=report_folder) | ||||
| @@ -16,5 +16,42 @@ | |||||
| from .hierarchical_tree import HierarchicalTree | from .hierarchical_tree import HierarchicalTree | ||||
| __all__ = [ | __all__ = [ | ||||
| "HierarchicalTree" | |||||
| "HierarchicalTreeFactory" | |||||
| ] | ] | ||||
| class HierarchicalTreeFactory: | |||||
| """Hierarchical tree factory.""" | |||||
| @classmethod | |||||
| def create(cls, graph): | |||||
| """ | |||||
| Factory method of hierarchical tree. | |||||
| Args: | |||||
| graph: Graph obj. | |||||
| Returns: | |||||
| HierarchicalTree, tree. | |||||
| """ | |||||
| tree = HierarchicalTree() | |||||
| node_input = None | |||||
| for _, node_name in enumerate(graph.nodes_in_topological_order): | |||||
| node_inst = graph.get_node(node_name) | |||||
| node_output = graph.get_output_shape(node_name) | |||||
| if node_inst.in_degree == 0: | |||||
| # If in-degree equals to zero, then it's a input node. | |||||
| continue | |||||
| # If the node is on the top, then fetch its input | |||||
| # from input table. | |||||
| if not node_input: | |||||
| node_input = graph.get_input_shape(node_name) | |||||
| if not node_input: | |||||
| raise ValueError(f"This model is not supported now. " | |||||
| f"Cannot find {node_name}'s input shape.") | |||||
| tree.insert(node_inst, node_name, node_input, node_output) | |||||
| node_input = node_output | |||||
| return tree | |||||
| @@ -53,10 +53,6 @@ class ModuleNameMgr(NameMgr): | |||||
| """Module name manager.""" | """Module name manager.""" | ||||
| class VariableNameMgrInModule(NameMgr): | |||||
| """Variable name mgr for a module.""" | |||||
| global_op_namespace = dict() | global_op_namespace = dict() | ||||
| START_IDX = 0 | START_IDX = 0 | ||||
| @@ -15,9 +15,6 @@ | |||||
| """Define graph entity.""" | """Define graph entity.""" | ||||
| import abc | import abc | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from typing import Dict, Union, Any | |||||
| from torch.nn import Module | |||||
| from ..constant import SEPARATOR_IN_ONNX_OP | from ..constant import SEPARATOR_IN_ONNX_OP | ||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| @@ -40,21 +37,13 @@ class BaseGraph(metaclass=abc.ABCMeta): | |||||
| def build(self, input_shape: tuple): | def build(self, input_shape: tuple): | ||||
| """Build graph.""" | """Build graph.""" | ||||
| @abc.abstractmethod | |||||
| def to_ir(self, mapper): | |||||
| """Convert graph to ir graph.""" | |||||
| @abc.abstractmethod | |||||
| def to_hierarchical_tree(self): | |||||
| """Convert to hierarchical tree.""" | |||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| def sub_graph_merging(self): | def sub_graph_merging(self): | ||||
| """Merge split nodes into one.""" | """Merge split nodes into one.""" | ||||
| @staticmethod | @staticmethod | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| def load_checkpoint(ckpt_path: str) -> Dict: | |||||
| def load_checkpoint(ckpt_path: str) -> dict: | |||||
| """Load checkpoint file.""" | """Load checkpoint file.""" | ||||
| @staticmethod | @staticmethod | ||||
| @@ -88,15 +77,14 @@ class Graph(BaseGraph, abc.ABC): | |||||
| Define Factory method to create Graph sub-class. | Define Factory method to create Graph sub-class. | ||||
| Args: | Args: | ||||
| model (Union[Module, Any]): Graph file. | |||||
| model (Union[torch.nn.Module, Any]): Graph file. | |||||
| checkpoint (dict): Checkpoint path. | checkpoint (dict): Checkpoint path. | ||||
| """ | """ | ||||
| sorted = False | sorted = False | ||||
| def __init__(self, model: Union[Module, Any], | |||||
| **kwargs): | |||||
| def __init__(self, model, **kwargs): | |||||
| super(Graph, self).__init__() | super(Graph, self).__init__() | ||||
| self.model = model | self.model = model | ||||
| self.checkpoint = kwargs.get("checkpoint", None) | self.checkpoint = kwargs.get("checkpoint", None) | ||||
| @@ -108,6 +96,27 @@ class Graph(BaseGraph, abc.ABC): | |||||
| self._topological_order = [] | self._topological_order = [] | ||||
| self._input_shape = dict() | self._input_shape = dict() | ||||
| def get_output_shape(self, name): | |||||
| """ | |||||
| Get node output shape. | |||||
| Args: | |||||
| name (str): Node name. | |||||
| Returns: | |||||
| list, shape. | |||||
| """ | |||||
| return self._shape_dict.get(name) | |||||
| def get_input_shape(self, name): | |||||
| """ | |||||
| Get node input shape. | |||||
| Returns: | |||||
| list, shape. | |||||
| """ | |||||
| return self._input_shape.get(name) | |||||
| @property | @property | ||||
| def nodes_in_topological_order(self): | def nodes_in_topological_order(self): | ||||
| """ | """ | ||||
| @@ -192,17 +201,11 @@ class Graph(BaseGraph, abc.ABC): | |||||
| idx += 1 | idx += 1 | ||||
| self.sorted = True | self.sorted = True | ||||
| def to_ir(self, mapper): | |||||
| raise NotImplementedError | |||||
| def to_hierarchical_tree(self): | |||||
| raise NotImplementedError | |||||
| def sub_graph_merging(self): | def sub_graph_merging(self): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @staticmethod | @staticmethod | ||||
| def load_checkpoint(ckpt_path: str) -> Dict: | |||||
| def load_checkpoint(ckpt_path: str) -> dict: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @staticmethod | @staticmethod | ||||
| @@ -17,18 +17,11 @@ import warnings | |||||
| import re | import re | ||||
| from typing import Dict, NoReturn | from typing import Dict, NoReturn | ||||
| import torch | |||||
| from torch.nn import Module | |||||
| from torch.onnx import OperatorExportTypes | |||||
| from .base import Graph | from .base import Graph | ||||
| from .input_node import InputNode | from .input_node import InputNode | ||||
| from .pytorch_graph_node import PyTorchGraphNode | from .pytorch_graph_node import PyTorchGraphNode | ||||
| from .graph_parser import PyTorchGraphParser | from .graph_parser import PyTorchGraphParser | ||||
| from .torch_utils import OverloadTorchModuleTemporarily, unique_state_dict | |||||
| from .torch_utils import create_autograd_variable | |||||
| from .torch_utils import onnx_tracer | |||||
| from ..hierarchical_tree import HierarchicalTree | |||||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE | from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE | ||||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | from ..constant import LEFT_BUCKET, RIGHT_BUCKET | ||||
| @@ -78,8 +71,11 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| def __init__(self, model: Module, sample_shape: tuple): | |||||
| def __init__(self, model, sample_shape: tuple): | |||||
| super(PyTorchGraph, self).__init__(model=model) | super(PyTorchGraph, self).__init__(model=model) | ||||
| from .torch_utils import unique_state_dict | |||||
| self._params_dict = unique_state_dict(model) | self._params_dict = unique_state_dict(model) | ||||
| self.build(sample_shape) | self.build(sample_shape) | ||||
| @@ -108,6 +104,12 @@ class PyTorchGraph(Graph): | |||||
| input_shape (tuple): Input shape of model. | input_shape (tuple): Input shape of model. | ||||
| """ | """ | ||||
| import torch | |||||
| from torch.onnx import OperatorExportTypes | |||||
| from .torch_utils import OverloadTorchModuleTemporarily | |||||
| from .torch_utils import create_autograd_variable | |||||
| from .torch_utils import onnx_tracer | |||||
| self._check_input_shape(input_shape) | self._check_input_shape(input_shape) | ||||
| def _extract_shape(shape): | def _extract_shape(shape): | ||||
| @@ -188,32 +190,6 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| raise NotImplementedError() | raise NotImplementedError() | ||||
| def to_hierarchical_tree(self): | |||||
| """ | |||||
| Generate hierarchical tree based on graph. | |||||
| """ | |||||
| tree = HierarchicalTree() | |||||
| node_input = None | |||||
| for _, node_name in enumerate(self.nodes_in_topological_order): | |||||
| node_inst = self.get_node(node_name) | |||||
| node_output = self._shape_dict.get(node_name) | |||||
| if node_inst.in_degree == 0: | |||||
| # If in-degree equals to zero, then it's a input node. | |||||
| continue | |||||
| # If the node is on the top, then fetch its input | |||||
| # from input table. | |||||
| if not node_input: | |||||
| node_input = self._input_shape.get(node_name) | |||||
| if not node_input: | |||||
| raise ValueError(f"This model is not supported now. " | |||||
| f"Cannot find {node_name}'s input shape.") | |||||
| tree.insert(node_inst, node_name, node_input, node_output) | |||||
| node_input = node_output | |||||
| return tree | |||||
| def build_connection(self, src, tgt) -> NoReturn: | def build_connection(self, src, tgt) -> NoReturn: | ||||
| """ | """ | ||||
| Build connection between source node and target node. | Build connection between source node and target node. | ||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define PyTorch graph node.""" | """Define PyTorch graph node.""" | ||||
| from .base import GraphNode | from .base import GraphNode | ||||
| from .torch_utils import getitem_of_node | |||||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | ||||
| SEPARATOR_IN_ONNX_OP | SEPARATOR_IN_ONNX_OP | ||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| @@ -202,6 +202,8 @@ class PyTorchGraphNode(GraphNode): | |||||
| Returns: | Returns: | ||||
| dict, raw params. | dict, raw params. | ||||
| """ | """ | ||||
| from .torch_utils import getitem_of_node | |||||
| raw_params = dict() | raw_params = dict() | ||||
| if not node: | if not node: | ||||
| @@ -16,7 +16,6 @@ | |||||
| import importlib | import importlib | ||||
| from torch.nn import Module | from torch.nn import Module | ||||
| from torch.jit import _unique_state_dict | |||||
| from torch.onnx.utils import _trace | from torch.onnx.utils import _trace | ||||
| from torch.onnx.utils import _node_getitem | from torch.onnx.utils import _node_getitem | ||||
| @@ -36,6 +35,8 @@ def unique_state_dict(model): | |||||
| Returns: | Returns: | ||||
| dict, params. | dict, params. | ||||
| """ | """ | ||||
| from torch.jit import _unique_state_dict | |||||
| return _unique_state_dict(model) | return _unique_state_dict(model) | ||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| @@ -0,0 +1,26 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Test name manager module.""" | |||||
| from unittest import TestCase | |||||
| from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import GlobalVarNameMgr | |||||
| class TestNameMgr(TestCase): | |||||
| """Tester of name mgr.""" | |||||
| def test_global_name_mgr(self): | |||||
| """Test global name mgr.""" | |||||
| name = GlobalVarNameMgr().get_name("onnx::Conv") | |||||
| assert isinstance(name, str) | |||||