Browse Source

Solve cycle-ref.

tags/v1.0.0
liuchongming 5 years ago
parent
commit
407a3bd1c8
10 changed files with 135 additions and 65 deletions
  1. +2
    -1
      mindinsight/mindconverter/graph_based_converter/framework.py
  2. +38
    -1
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  3. +0
    -4
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
  4. +25
    -22
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  5. +11
    -35
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  6. +3
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py
  7. +2
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py
  8. +14
    -0
      tests/ut/mindconverter/graph_based_converter/__init__.py
  9. +14
    -0
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  10. +26
    -0
      tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py

+ 2
- 1
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -84,10 +84,11 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,

"""
from .third_party_graph import GraphFactory
from .hierarchical_tree import HierarchicalTreeFactory

graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
checkpoint=checkpoint_path)
hierarchical_tree = graph_obj.to_hierarchical_tree()
hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
report_folder=report_folder)



+ 38
- 1
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -16,5 +16,42 @@
from .hierarchical_tree import HierarchicalTree

__all__ = [
"HierarchicalTree"
"HierarchicalTreeFactory"
]


class HierarchicalTreeFactory:
"""Hierarchical tree factory."""

@classmethod
def create(cls, graph):
"""
Factory method of hierarchical tree.

Args:
graph: Graph obj.

Returns:
HierarchicalTree, tree.
"""
tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(graph.nodes_in_topological_order):
node_inst = graph.get_node(node_name)
node_output = graph.get_output_shape(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = graph.get_input_shape(node_name)

if not node_input:
raise ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

+ 0
- 4
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py View File

@@ -53,10 +53,6 @@ class ModuleNameMgr(NameMgr):
"""Module name manager."""


class VariableNameMgrInModule(NameMgr):
"""Variable name mgr for a module."""


global_op_namespace = dict()
START_IDX = 0



+ 25
- 22
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -15,9 +15,6 @@
"""Define graph entity."""
import abc
from collections import OrderedDict
from typing import Dict, Union, Any

from torch.nn import Module

from ..constant import SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper
@@ -40,21 +37,13 @@ class BaseGraph(metaclass=abc.ABCMeta):
def build(self, input_shape: tuple):
"""Build graph."""

@abc.abstractmethod
def to_ir(self, mapper):
"""Convert graph to ir graph."""

@abc.abstractmethod
def to_hierarchical_tree(self):
"""Convert to hierarchical tree."""

@abc.abstractmethod
def sub_graph_merging(self):
"""Merge split nodes into one."""

@staticmethod
@abc.abstractmethod
def load_checkpoint(ckpt_path: str) -> Dict:
def load_checkpoint(ckpt_path: str) -> dict:
"""Load checkpoint file."""

@staticmethod
@@ -88,15 +77,14 @@ class Graph(BaseGraph, abc.ABC):
Define Factory method to create Graph sub-class.

Args:
model (Union[Module, Any]): Graph file.
model (Union[torch.nn.Module, Any]): Graph file.
checkpoint (dict): Checkpoint path.

"""

sorted = False

def __init__(self, model: Union[Module, Any],
**kwargs):
def __init__(self, model, **kwargs):
super(Graph, self).__init__()
self.model = model
self.checkpoint = kwargs.get("checkpoint", None)
@@ -108,6 +96,27 @@ class Graph(BaseGraph, abc.ABC):
self._topological_order = []
self._input_shape = dict()

def get_output_shape(self, name):
"""
Get node output shape.

Args:
name (str): Node name.

Returns:
list, shape.
"""
return self._shape_dict.get(name)

def get_input_shape(self, name):
"""
Get node input shape.

Returns:
list, shape.
"""
return self._input_shape.get(name)

@property
def nodes_in_topological_order(self):
"""
@@ -192,17 +201,11 @@ class Graph(BaseGraph, abc.ABC):
idx += 1
self.sorted = True

def to_ir(self, mapper):
raise NotImplementedError

def to_hierarchical_tree(self):
raise NotImplementedError

def sub_graph_merging(self):
raise NotImplementedError

@staticmethod
def load_checkpoint(ckpt_path: str) -> Dict:
def load_checkpoint(ckpt_path: str) -> dict:
raise NotImplementedError

@staticmethod


+ 11
- 35
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -17,18 +17,11 @@ import warnings
import re
from typing import Dict, NoReturn

import torch
from torch.nn import Module
from torch.onnx import OperatorExportTypes

from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
from .graph_parser import PyTorchGraphParser
from .torch_utils import OverloadTorchModuleTemporarily, unique_state_dict
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer
from ..hierarchical_tree import HierarchicalTree

from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
from ..constant import LEFT_BUCKET, RIGHT_BUCKET

@@ -78,8 +71,11 @@ class PyTorchGraph(Graph):

"""

def __init__(self, model: Module, sample_shape: tuple):
def __init__(self, model, sample_shape: tuple):
super(PyTorchGraph, self).__init__(model=model)

from .torch_utils import unique_state_dict

self._params_dict = unique_state_dict(model)
self.build(sample_shape)

@@ -108,6 +104,12 @@ class PyTorchGraph(Graph):
input_shape (tuple): Input shape of model.

"""
import torch
from torch.onnx import OperatorExportTypes
from .torch_utils import OverloadTorchModuleTemporarily
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer

self._check_input_shape(input_shape)

def _extract_shape(shape):
@@ -188,32 +190,6 @@ class PyTorchGraph(Graph):
"""
raise NotImplementedError()

def to_hierarchical_tree(self):
"""
Generate hierarchical tree based on graph.
"""
tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(self.nodes_in_topological_order):
node_inst = self.get_node(node_name)
node_output = self._shape_dict.get(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = self._input_shape.get(node_name)

if not node_input:
raise ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

def build_connection(self, src, tgt) -> NoReturn:
"""
Build connection between source node and target node.


+ 3
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -14,7 +14,7 @@
# ==============================================================================
"""Define PyTorch graph node."""
from .base import GraphNode
from .torch_utils import getitem_of_node
from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper
@@ -202,6 +202,8 @@ class PyTorchGraphNode(GraphNode):
Returns:
dict, raw params.
"""
from .torch_utils import getitem_of_node

raw_params = dict()

if not node:


+ 2
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py View File

@@ -16,7 +16,6 @@
import importlib

from torch.nn import Module
from torch.jit import _unique_state_dict
from torch.onnx.utils import _trace
from torch.onnx.utils import _node_getitem

@@ -36,6 +35,8 @@ def unique_state_dict(model):
Returns:
dict, params.
"""
from torch.jit import _unique_state_dict

return _unique_state_dict(model)




+ 14
- 0
tests/ut/mindconverter/graph_based_converter/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

+ 14
- 0
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

+ 26
- 0
tests/ut/mindconverter/graph_based_converter/hierarchical_tree/test_name_mgr.py View File

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test name manager module."""
from unittest import TestCase
from mindinsight.mindconverter.graph_based_converter.hierarchical_tree.name_mgr import GlobalVarNameMgr


class TestNameMgr(TestCase):
"""Tester of name mgr."""

def test_global_name_mgr(self):
"""Test global name mgr."""
name = GlobalVarNameMgr().get_name("onnx::Conv")
assert isinstance(name, str)

Loading…
Cancel
Save