| @@ -84,10 +84,11 @@ def graph_based_converter(graph_path: str, sample_shape: tuple, | |||
| """ | |||
| from .third_party_graph import GraphFactory | |||
| from .hierarchical_tree import HierarchicalTreeFactory | |||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||
| checkpoint=checkpoint_path) | |||
| hierarchical_tree = graph_obj.to_hierarchical_tree() | |||
| hierarchical_tree = HierarchicalTreeFactory.create(graph_obj) | |||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | |||
| report_folder=report_folder) | |||
| @@ -16,5 +16,42 @@ | |||
| from .hierarchical_tree import HierarchicalTree | |||
| __all__ = [ | |||
| "HierarchicalTree" | |||
| "HierarchicalTreeFactory" | |||
| ] | |||
| class HierarchicalTreeFactory: | |||
| """Hierarchical tree factory.""" | |||
| @classmethod | |||
| def create(cls, graph): | |||
| """ | |||
| Factory method of hierarchical tree. | |||
| Args: | |||
| graph: Graph obj. | |||
| Returns: | |||
| HierarchicalTree, tree. | |||
| """ | |||
| tree = HierarchicalTree() | |||
| node_input = None | |||
| for _, node_name in enumerate(graph.nodes_in_topological_order): | |||
| node_inst = graph.get_node(node_name) | |||
| node_output = graph.get_output_shape(node_name) | |||
| if node_inst.in_degree == 0: | |||
| # If in-degree equals to zero, then it's a input node. | |||
| continue | |||
| # If the node is on the top, then fetch its input | |||
| # from input table. | |||
| if not node_input: | |||
| node_input = graph.get_input_shape(node_name) | |||
| if not node_input: | |||
| raise ValueError(f"This model is not supported now. " | |||
| f"Cannot find {node_name}'s input shape.") | |||
| tree.insert(node_inst, node_name, node_input, node_output) | |||
| node_input = node_output | |||
| return tree | |||
| @@ -53,10 +53,6 @@ class ModuleNameMgr(NameMgr): | |||
| """Module name manager.""" | |||
| class VariableNameMgrInModule(NameMgr): | |||
| """Variable name mgr for a module.""" | |||
| global_op_namespace = dict() | |||
| START_IDX = 0 | |||
| @@ -15,9 +15,6 @@ | |||
| """Define graph entity.""" | |||
| import abc | |||
| from collections import OrderedDict | |||
| from typing import Dict, Union, Any | |||
| from torch.nn import Module | |||
| from ..constant import SEPARATOR_IN_ONNX_OP | |||
| from ..mapper.base import Mapper | |||
| @@ -40,21 +37,13 @@ class BaseGraph(metaclass=abc.ABCMeta): | |||
| def build(self, input_shape: tuple): | |||
| """Build graph.""" | |||
| @abc.abstractmethod | |||
| def to_ir(self, mapper): | |||
| """Convert graph to ir graph.""" | |||
| @abc.abstractmethod | |||
| def to_hierarchical_tree(self): | |||
| """Convert to hierarchical tree.""" | |||
| @abc.abstractmethod | |||
| def sub_graph_merging(self): | |||
| """Merge split nodes into one.""" | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def load_checkpoint(ckpt_path: str) -> Dict: | |||
| def load_checkpoint(ckpt_path: str) -> dict: | |||
| """Load checkpoint file.""" | |||
| @staticmethod | |||
| @@ -88,15 +77,14 @@ class Graph(BaseGraph, abc.ABC): | |||
| Define Factory method to create Graph sub-class. | |||
| Args: | |||
| model (Union[Module, Any]): Graph file. | |||
| model (Union[torch.nn.Module, Any]): Graph file. | |||
| checkpoint (dict): Checkpoint path. | |||
| """ | |||
| sorted = False | |||
| def __init__(self, model: Union[Module, Any], | |||
| **kwargs): | |||
| def __init__(self, model, **kwargs): | |||
| super(Graph, self).__init__() | |||
| self.model = model | |||
| self.checkpoint = kwargs.get("checkpoint", None) | |||
| @@ -108,6 +96,27 @@ class Graph(BaseGraph, abc.ABC): | |||
| self._topological_order = [] | |||
| self._input_shape = dict() | |||
| def get_output_shape(self, name): | |||
| """ | |||
| Get node output shape. | |||
| Args: | |||
| name (str): Node name. | |||
| Returns: | |||
| list, shape. | |||
| """ | |||
| return self._shape_dict.get(name) | |||
| def get_input_shape(self, name): | |||
| """ | |||
| Get node input shape. | |||
| Returns: | |||
| list, shape. | |||
| """ | |||
| return self._input_shape.get(name) | |||
| @property | |||
| def nodes_in_topological_order(self): | |||
| """ | |||
| @@ -192,17 +201,11 @@ class Graph(BaseGraph, abc.ABC): | |||
| idx += 1 | |||
| self.sorted = True | |||
| def to_ir(self, mapper): | |||
| raise NotImplementedError | |||
| def to_hierarchical_tree(self): | |||
| raise NotImplementedError | |||
| def sub_graph_merging(self): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def load_checkpoint(ckpt_path: str) -> Dict: | |||
| def load_checkpoint(ckpt_path: str) -> dict: | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| @@ -17,18 +17,11 @@ import warnings | |||
| import re | |||
| from typing import Dict, NoReturn | |||
| import torch | |||
| from torch.nn import Module | |||
| from torch.onnx import OperatorExportTypes | |||
| from .base import Graph | |||
| from .input_node import InputNode | |||
| from .pytorch_graph_node import PyTorchGraphNode | |||
| from .graph_parser import PyTorchGraphParser | |||
| from .torch_utils import OverloadTorchModuleTemporarily, unique_state_dict | |||
| from .torch_utils import create_autograd_variable | |||
| from .torch_utils import onnx_tracer | |||
| from ..hierarchical_tree import HierarchicalTree | |||
| from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE | |||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | |||
| @@ -78,8 +71,11 @@ class PyTorchGraph(Graph): | |||
| """ | |||
| def __init__(self, model: Module, sample_shape: tuple): | |||
| def __init__(self, model, sample_shape: tuple): | |||
| super(PyTorchGraph, self).__init__(model=model) | |||
| from .torch_utils import unique_state_dict | |||
| self._params_dict = unique_state_dict(model) | |||
| self.build(sample_shape) | |||
| @@ -108,6 +104,12 @@ class PyTorchGraph(Graph): | |||
| input_shape (tuple): Input shape of model. | |||
| """ | |||
| import torch | |||
| from torch.onnx import OperatorExportTypes | |||
| from .torch_utils import OverloadTorchModuleTemporarily | |||
| from .torch_utils import create_autograd_variable | |||
| from .torch_utils import onnx_tracer | |||
| self._check_input_shape(input_shape) | |||
| def _extract_shape(shape): | |||
| @@ -188,32 +190,6 @@ class PyTorchGraph(Graph): | |||
| """ | |||
| raise NotImplementedError() | |||
| def to_hierarchical_tree(self): | |||
| """ | |||
| Generate hierarchical tree based on graph. | |||
| """ | |||
| tree = HierarchicalTree() | |||
| node_input = None | |||
| for _, node_name in enumerate(self.nodes_in_topological_order): | |||
| node_inst = self.get_node(node_name) | |||
| node_output = self._shape_dict.get(node_name) | |||
| if node_inst.in_degree == 0: | |||
| # If in-degree equals to zero, then it's a input node. | |||
| continue | |||
| # If the node is on the top, then fetch its input | |||
| # from input table. | |||
| if not node_input: | |||
| node_input = self._input_shape.get(node_name) | |||
| if not node_input: | |||
| raise ValueError(f"This model is not supported now. " | |||
| f"Cannot find {node_name}'s input shape.") | |||
| tree.insert(node_inst, node_name, node_input, node_output) | |||
| node_input = node_output | |||
| return tree | |||
| def build_connection(self, src, tgt) -> NoReturn: | |||
| """ | |||
| Build connection between source node and target node. | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================== | |||
| """Define PyTorch graph node.""" | |||
| from .base import GraphNode | |||
| from .torch_utils import getitem_of_node | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
| SEPARATOR_IN_ONNX_OP | |||
| from ..mapper.base import Mapper | |||
| @@ -202,6 +202,8 @@ class PyTorchGraphNode(GraphNode): | |||
| Returns: | |||
| dict, raw params. | |||
| """ | |||
| from .torch_utils import getitem_of_node | |||
| raw_params = dict() | |||
| if not node: | |||
| @@ -16,7 +16,6 @@ | |||
| import importlib | |||
| from torch.nn import Module | |||
| from torch.jit import _unique_state_dict | |||
| from torch.onnx.utils import _trace | |||
| from torch.onnx.utils import _node_getitem | |||
| @@ -36,6 +35,8 @@ def unique_state_dict(model): | |||
| Returns: | |||
| dict, params. | |||
| """ | |||
| from torch.jit import _unique_state_dict | |||
| return _unique_state_dict(model) | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| @@ -0,0 +1,26 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Test name manager module.""" | |||
| from unittest import TestCase | |||
| from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import GlobalVarNameMgr | |||
| class TestNameMgr(TestCase): | |||
| """Tester of name mgr.""" | |||
| def test_global_name_mgr(self): | |||
| """Test global name mgr.""" | |||
| name = GlobalVarNameMgr().get_name("onnx::Conv") | |||
| assert isinstance(name, str) | |||