Browse Source

Solve cycle-ref.

tags/v1.0.0
liuchongming 5 years ago
parent
commit
407a3bd1c8
10 changed files with 135 additions and 65 deletions
  1. +2
    -1
      mindinsight/mindconverter/graph_based_converter/framework.py
  2. +38
    -1
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  3. +0
    -4
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
  4. +25
    -22
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  5. +11
    -35
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  6. +3
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py
  7. +2
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py
  8. +14
    -0
      tests/ut/mindconverter/graph_based_converter/__init__.py
  9. +14
    -0
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  10. +26
    -0
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py

+ 2
- 1
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -84,10 +84,11 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,


""" """
from .third_party_graph import GraphFactory from .third_party_graph import GraphFactory
from .hierarchical_tree import HierarchicalTreeFactory


graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
checkpoint=checkpoint_path) checkpoint=checkpoint_path)
hierarchical_tree = graph_obj.to_hierarchical_tree()
hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
report_folder=report_folder) report_folder=report_folder)




+ 38
- 1
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -16,5 +16,42 @@
from .hierarchical_tree import HierarchicalTree from .hierarchical_tree import HierarchicalTree


__all__ = [ __all__ = [
"HierarchicalTree"
"HierarchicalTreeFactory"
] ]


class HierarchicalTreeFactory:
"""Hierarchical tree factory."""

@classmethod
def create(cls, graph):
"""
Factory method of hierarchical tree.

Args:
graph: Graph obj.

Returns:
HierarchicalTree, tree.
"""
tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(graph.nodes_in_topological_order):
node_inst = graph.get_node(node_name)
node_output = graph.get_output_shape(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = graph.get_input_shape(node_name)

if not node_input:
raise ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

+ 0
- 4
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py View File

@@ -53,10 +53,6 @@ class ModuleNameMgr(NameMgr):
"""Module name manager.""" """Module name manager."""




class VariableNameMgrInModule(NameMgr):
"""Variable name mgr for a module."""


global_op_namespace = dict() global_op_namespace = dict()
START_IDX = 0 START_IDX = 0




+ 25
- 22
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -15,9 +15,6 @@
"""Define graph entity.""" """Define graph entity."""
import abc import abc
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, Union, Any

from torch.nn import Module


from ..constant import SEPARATOR_IN_ONNX_OP from ..constant import SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper from ..mapper.base import Mapper
@@ -40,21 +37,13 @@ class BaseGraph(metaclass=abc.ABCMeta):
def build(self, input_shape: tuple): def build(self, input_shape: tuple):
"""Build graph.""" """Build graph."""


@abc.abstractmethod
def to_ir(self, mapper):
"""Convert graph to ir graph."""

@abc.abstractmethod
def to_hierarchical_tree(self):
"""Convert to hierarchical tree."""

@abc.abstractmethod @abc.abstractmethod
def sub_graph_merging(self): def sub_graph_merging(self):
"""Merge split nodes into one.""" """Merge split nodes into one."""


@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def load_checkpoint(ckpt_path: str) -> Dict:
def load_checkpoint(ckpt_path: str) -> dict:
"""Load checkpoint file.""" """Load checkpoint file."""


@staticmethod @staticmethod
@@ -88,15 +77,14 @@ class Graph(BaseGraph, abc.ABC):
Define Factory method to create Graph sub-class. Define Factory method to create Graph sub-class.


Args: Args:
model (Union[Module, Any]): Graph file.
model (Union[torch.nn.Module, Any]): Graph file.
checkpoint (dict): Checkpoint path. checkpoint (dict): Checkpoint path.


""" """


sorted = False sorted = False


def __init__(self, model: Union[Module, Any],
**kwargs):
def __init__(self, model, **kwargs):
super(Graph, self).__init__() super(Graph, self).__init__()
self.model = model self.model = model
self.checkpoint = kwargs.get("checkpoint", None) self.checkpoint = kwargs.get("checkpoint", None)
@@ -108,6 +96,27 @@ class Graph(BaseGraph, abc.ABC):
self._topological_order = [] self._topological_order = []
self._input_shape = dict() self._input_shape = dict()


def get_output_shape(self, name):
"""
Get node output shape.

Args:
name (str): Node name.

Returns:
list, shape.
"""
return self._shape_dict.get(name)

def get_input_shape(self, name):
"""
Get node input shape.

Returns:
list, shape.
"""
return self._input_shape.get(name)

@property @property
def nodes_in_topological_order(self): def nodes_in_topological_order(self):
""" """
@@ -192,17 +201,11 @@ class Graph(BaseGraph, abc.ABC):
idx += 1 idx += 1
self.sorted = True self.sorted = True


def to_ir(self, mapper):
raise NotImplementedError

def to_hierarchical_tree(self):
raise NotImplementedError

def sub_graph_merging(self): def sub_graph_merging(self):
raise NotImplementedError raise NotImplementedError


@staticmethod @staticmethod
def load_checkpoint(ckpt_path: str) -> Dict:
def load_checkpoint(ckpt_path: str) -> dict:
raise NotImplementedError raise NotImplementedError


@staticmethod @staticmethod


+ 11
- 35
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -17,18 +17,11 @@ import warnings
import re import re
from typing import Dict, NoReturn from typing import Dict, NoReturn


import torch
from torch.nn import Module
from torch.onnx import OperatorExportTypes

from .base import Graph from .base import Graph
from .input_node import InputNode from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode from .pytorch_graph_node import PyTorchGraphNode
from .graph_parser import PyTorchGraphParser from .graph_parser import PyTorchGraphParser
from .torch_utils import OverloadTorchModuleTemporarily, unique_state_dict
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer
from ..hierarchical_tree import HierarchicalTree

from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
from ..constant import LEFT_BUCKET, RIGHT_BUCKET from ..constant import LEFT_BUCKET, RIGHT_BUCKET


@@ -78,8 +71,11 @@ class PyTorchGraph(Graph):


""" """


def __init__(self, model: Module, sample_shape: tuple):
def __init__(self, model, sample_shape: tuple):
super(PyTorchGraph, self).__init__(model=model) super(PyTorchGraph, self).__init__(model=model)

from .torch_utils import unique_state_dict

self._params_dict = unique_state_dict(model) self._params_dict = unique_state_dict(model)
self.build(sample_shape) self.build(sample_shape)


@@ -108,6 +104,12 @@ class PyTorchGraph(Graph):
input_shape (tuple): Input shape of model. input_shape (tuple): Input shape of model.


""" """
import torch
from torch.onnx import OperatorExportTypes
from .torch_utils import OverloadTorchModuleTemporarily
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer

self._check_input_shape(input_shape) self._check_input_shape(input_shape)


def _extract_shape(shape): def _extract_shape(shape):
@@ -188,32 +190,6 @@ class PyTorchGraph(Graph):
""" """
raise NotImplementedError() raise NotImplementedError()


def to_hierarchical_tree(self):
"""
Generate hierarchical tree based on graph.
"""
tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(self.nodes_in_topological_order):
node_inst = self.get_node(node_name)
node_output = self._shape_dict.get(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = self._input_shape.get(node_name)

if not node_input:
raise ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

def build_connection(self, src, tgt) -> NoReturn: def build_connection(self, src, tgt) -> NoReturn:
""" """
Build connection between source node and target node. Build connection between source node and target node.


+ 3
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -14,7 +14,7 @@
# ============================================================================== # ==============================================================================
"""Define PyTorch graph node.""" """Define PyTorch graph node."""
from .base import GraphNode from .base import GraphNode
from .torch_utils import getitem_of_node
from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper from ..mapper.base import Mapper
@@ -202,6 +202,8 @@ class PyTorchGraphNode(GraphNode):
Returns: Returns:
dict, raw params. dict, raw params.
""" """
from .torch_utils import getitem_of_node

raw_params = dict() raw_params = dict()


if not node: if not node:


+ 2
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py View File

@@ -16,7 +16,6 @@
import importlib import importlib


from torch.nn import Module from torch.nn import Module
from torch.jit import _unique_state_dict
from torch.onnx.utils import _trace from torch.onnx.utils import _trace
from torch.onnx.utils import _node_getitem from torch.onnx.utils import _node_getitem


@@ -36,6 +35,8 @@ def unique_state_dict(model):
Returns: Returns:
dict, params. dict, params.
""" """
from torch.jit import _unique_state_dict

return _unique_state_dict(model) return _unique_state_dict(model)






+ 14
- 0
tests/ut/mindconverter/graph_based_converter/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

+ 14
- 0
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

+ 26
- 0
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py View File

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test name manager module."""
from unittest import TestCase
from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import GlobalVarNameMgr


class TestNameMgr(TestCase):
"""Tester of name mgr."""

def test_global_name_mgr(self):
"""Test global name mgr."""
name = GlobalVarNameMgr().get_name("onnx::Conv")
assert isinstance(name, str)

Loading…
Cancel
Save