Browse Source

!640 Generalization of MindConverter

Merge pull request !640 from 刘崇鸣/generalize_mindconverter
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
1b95f3b6d1
11 changed files with 239 additions and 146 deletions
  1. +8
    -1
      mindinsight/mindconverter/graph_based_converter/framework.py
  2. +5
    -13
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  3. +65
    -64
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  4. +18
    -8
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
  5. +6
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  6. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  7. +1
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py
  8. +14
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  9. +23
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
  10. +76
    -44
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  11. +22
    -6
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py

+ 8
- 1
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -20,6 +20,7 @@ from importlib.util import find_spec
import mindinsight
from mindinsight.mindconverter.common.log import logger as log
from .mapper import ONNXToMindSporeMapper
from ..common.exceptions import NodeTypeNotSupport

permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
@@ -92,7 +93,13 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,

graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
checkpoint=checkpoint_path)
hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
try:
hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
except Exception as e:
log.exception(e)
log.error("Error occur when create hierarchical tree.")
raise NodeTypeNotSupport("This model is not supported now.")

hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
report_folder=report_folder)



+ 5
- 13
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Hierarchical tree module."""
from mindinsight.mindconverter.common.log import logger as log
from .hierarchical_tree import HierarchicalTree

__all__ = [
@@ -35,23 +36,14 @@ class HierarchicalTreeFactory:
HierarchicalTree, tree.
"""
tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(graph.nodes_in_topological_order):
node_inst = graph.get_node(node_name)
node_input = graph.get_input_shape(node_name)
node_output = graph.get_output_shape(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = graph.get_input_shape(node_name)

if not node_input:
raise ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")
err_msg = f"This model is not supported now. " \
f"Cannot find {node_name}'s input shape."
log.error(err_msg)

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

+ 65
- 64
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -31,6 +31,7 @@ from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVE
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
from ..constant import NodeType
from ..report_generator import ReportGenerator
from ...common.exceptions import NodeTypeNotSupport

GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()

@@ -56,7 +57,7 @@ class HierarchicalTree(Tree):
# Manage module name to used.
self._module_mgr = ModuleNameMgr()
# Manage variable name in a module.
self._args_mgr_in_module = dict()
self._vars_mgr_in_module = dict()
self._module_vars = dict()

@property
@@ -86,7 +87,7 @@ class HierarchicalTree(Tree):
parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \
if not parent else scope
if parent else scope
if self.contains(try_parent):
# Whether current node existed.
parent = try_parent
@@ -132,6 +133,8 @@ class HierarchicalTree(Tree):
"""
Shrink sub-tree into one node.

Use child node to replace its ancestor.

Args:
node (Node): List of nodes to be merged.

@@ -140,6 +143,8 @@ class HierarchicalTree(Tree):
parent_node = self[node.predecessor(self.tree_identifier)]
# Keep successors of parent.
brothers = deepcopy(parent_node.successors(self.tree_identifier))
# Because shrink occurs when node has only one child,
# so we take index-0.
child = node.successors(self.tree_identifier)[0]
self.move_node(source=child,
destination=node.predecessor(self.tree_identifier))
@@ -158,9 +163,13 @@ class HierarchicalTree(Tree):
out_folder (str): Output folder.

"""
self._adjust_structure()

code_fragments = self._generate_codes(mapper)
try:
self._adjust_structure()
code_fragments = self._generate_codes(mapper)
except Exception as e:
log.exception(e)
log.error("Error occur when create hierarchical tree.")
raise NodeTypeNotSupport("This model is not supported now.")

out_folder = os.path.abspath(out_folder)
if not report_folder:
@@ -176,9 +185,8 @@ class HierarchicalTree(Tree):
for file_name in code_fragments:
code, report = code_fragments[file_name]
try:
with os.fdopen(
os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
self.flags, self.modes), 'w') as file:
with os.fdopen(os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
self.flags, self.modes), "w") as file:
file.write(code)
except IOError as error:
log.error(str(error))
@@ -186,9 +194,8 @@ class HierarchicalTree(Tree):
raise error

try:
with os.fdopen(
os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
self.flags, stat.S_IRUSR), "w") as rpt_f:
with os.fdopen(os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
self.flags, stat.S_IRUSR), "w") as rpt_f:
rpt_f.write(report)
except IOError as error:
log.error(str(error))
@@ -223,7 +230,8 @@ class HierarchicalTree(Tree):
Returns:
Node, node.
"""
if node.data.node_type == NodeType.MODULE.value:
if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value,
NodeType.FUNC.value}:
# If current node is class or function, then
# remove unused args in __init__.
cur_module_key = node.data.hash_key or self.hash_key(node)
@@ -231,13 +239,15 @@ class HierarchicalTree(Tree):
node = self._clear_unused_args(node,
self._merged_module_args[cur_module_key])

# `self._merged_module_args` records formal args.
# We need to replace actual args.
if precursor_module_key in self._merged_module_args:
# If parent node is in `_merged_module_args`, then
# replace current node args with arg name declared
# in _merged_module_args.
for arg in node.data.args_in_code.keys():
if arg in self._merged_module_args[precursor_module_key]:
node.data.replace_with_arg(arg)
node.data.replace_with_arg(arg, arg)
return node

@staticmethod
@@ -254,7 +264,8 @@ class HierarchicalTree(Tree):
"""
args_in_code = list(node.data.args_in_code.keys())
for arg in args_in_code:
if arg not in used_args:
ori_arg = arg.replace(f"_{node.data.variable_name}", "")
if ori_arg not in used_args:
node.data.args_in_code.pop(arg)
return node

@@ -287,9 +298,11 @@ class HierarchicalTree(Tree):
if node.data.node_type == NodeType.MODULE.value:
self._create_module_args_and_vars(node, mapper)

# 2. Get nodes can be merged.
self._module_merging(node_collection)
# Module merging based on all nodes.
self._module_merging()

for depth in depths:
node_collection = self._hierarchical_order[depth]
snippets = set()
for node_name in node_collection:
nd_inst = self.get_node(node_name)
@@ -297,8 +310,7 @@ class HierarchicalTree(Tree):
continue

# Generate hash key for node.
module_key = self.hash_key(nd_inst)

module_key = nd_inst.data.hash_key
# Get code generation func.
func, node_type = self._fetch_func_and_type(nd_inst)

@@ -325,9 +337,8 @@ class HierarchicalTree(Tree):
# 3. Pre-process node args.
nd_inst = self._preprocess_node_args(nd_inst, module_key)
# 4. Post-process child node args.
for scsr_nd_name in nd_inst.successors(self.tree_identifier):
self._postprocess_node_args(self.get_node(scsr_nd_name),
module_key)
for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)):
self._postprocess_node_args(self.get_node(scsr_nd_name), module_key)
# 5. Generate code.
snippets.add(func(nd_inst, nd_inst.data.module_name, module_key))

@@ -335,7 +346,6 @@ class HierarchicalTree(Tree):

formatted_code, _ = FormatCode("".join(code_blocks),
style_config=CodeFormatConfig.PEP8.value)

report_generator = ReportGenerator()
report = report_generator.gen_report(formatted_code)

@@ -403,9 +413,9 @@ class HierarchicalTree(Tree):
"output_shape": c_nd.data.output_shape})

# Generate code statement.
expr = ", ".join([f"{k}={v}" for k, v in args.items()])
expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}"
for k, v in args.items()])
code_line = f"{operator}({expr})"

module_list.append(code_line)

body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list)
@@ -435,6 +445,7 @@ class HierarchicalTree(Tree):
if class_key.lower() in self._merged_module_args and \
self._merged_module_args[class_key.lower()]:
args = f"{', '.join(self._merged_module_args[class_key.lower()])}"

class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \
f"{args}):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \
@@ -455,7 +466,8 @@ class HierarchicalTree(Tree):
construct_block.append(construct)
init_block.append(init)

class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):{NEW_LINE}{SECOND_LEVEL_INDENT}"
class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \
f"{NEW_LINE}{SECOND_LEVEL_INDENT}"
init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block)
csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block)
csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}"
@@ -514,7 +526,7 @@ class HierarchicalTree(Tree):

def _find_all_previous_opt_var_(self, cur_nd, pre_nd):
"""
Find all input varian names.
Find all input variable names.

Args:
cur_nd (Node): Current node.
@@ -585,61 +597,46 @@ class HierarchicalTree(Tree):
node.data.hash_key = unique_key
return unique_key

def _module_merging(self, nodes):
"""
Generate sub-module and corresponding params.

Args:
nodes (List[str]): Nodes name.

"""
merged_module = dict()
def _module_merging(self):
"""Generate sub-module and corresponding params."""
merged_module_args = dict()
for node_name in nodes:
nd_inst = self.get_node(node_name)
if nd_inst.data.node_type != NodeType.MODULE.value:
continue

module_key = self.hash_key(nd_inst)
if module_key not in merged_module:
merged_module[module_key] = [nd_inst.data.args_in_code]
else:
merged_module[module_key].append(nd_inst.data.args_in_code)

for module_key, module_args in merged_module.items():
for module_key, module_args in self._merged_module.items():
if module_key not in merged_module_args:
merged_module_args[module_key] = []
# Take first element's args as base.
keys = module_args[0].keys()
for key in keys:
for i in range(1, len(module_args)):
if module_args[0][key] != module_args[i][key]:
if key in module_args[i] and module_args[0][key] != module_args[i][key]:
merged_module_args[module_key].append(key)
break
if key not in module_args[i]:
merged_module_args[module_key].append(key)
break

self._merged_module.update(merged_module)
self._merged_module_args.update(merged_module_args)

def _create_module_args_and_vars(self, node, mapper):
"""
Create module args.
Create module args and variables in current node.

Args:
node (Node): Node on tree.
mapper (Mapper): Mapper of params.

"""
# All args and value pair in current node module.
module_args = dict()
module_key = self.hash_key(node)

created = False

if module_key not in self._args_mgr_in_module:
self._args_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR
if module_key not in self._vars_mgr_in_module:
self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR
self._module_vars[module_key] = []
else:
created = True

# Sub-modules in the module could have arg name conflicts.
for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
nd_inst = self.get_node(successor_name)
# Generate variable name here, then
@@ -648,12 +645,11 @@ class HierarchicalTree(Tree):
nd_inst.data.variable_name = self._module_vars[module_key][idx]
else:
variable_name = nd_inst.data.op_name or nd_inst.data.module_name
variable_name = self._args_mgr_in_module[module_key].get_name(variable_name)
variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name)
nd_inst.data.variable_name = variable_name

if nd_inst.data.node_type == NodeType.OPERATION.value:
# Generation of params must behind variable assigment.
nd_inst.data.param_transform(mapper)
# Generation of params must behind variable assigment.
nd_inst.data.param_transform(mapper)

module_args.update(nd_inst.data.args_in_code)

@@ -662,6 +658,12 @@ class HierarchicalTree(Tree):

node.data.args_in_code = module_args

# Collect module args of `module_key`.
if module_key not in self._merged_module:
self._merged_module[module_key] = [node.data.args_in_code]
else:
self._merged_module[module_key].append(node.data.args_in_code)

@staticmethod
def _create_operation_args(node, mapper):
"""
@@ -692,21 +694,20 @@ class HierarchicalTree(Tree):
self._hierarchical_order = hierarchical_order

def sub_graph_merging(self) -> NoReturn:
"""
Shrink subtree.
"""
"""Shrink the module has only one child."""
self.update_hierarchical_order()
depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
for depth in depths:
for node_name in self._hierarchical_order[depth]:
node_inst = self[node_name]
if not node_inst.data and len(node_inst.successors(self.tree_identifier)) == 1:
# If the node type is module and has only one child,
# then merge it with its child.
if node_inst.data.node_type == NodeType.MODULE.value and \
len(node_inst.successors(self.tree_identifier)) == 1:
self.shrink(node_inst)

def _adjust_structure(self) -> NoReturn:
"""
Adjust tree structure to generate source code.
"""
"""Adjust tree structure to generate source code."""
self.sub_graph_merging()
self.update_hierarchical_order()



+ 18
- 8
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py View File

@@ -53,6 +53,9 @@ class ModuleNameMgr(NameMgr):
"""Module name manager."""


# Manage variable name of different modules.
global_var_namespace = set()
# Manage variable name of different type.
global_op_namespace = dict()
START_IDX = 0

@@ -81,14 +84,21 @@ class GlobalVarNameMgr:
Returns:
str, module name.
"""
op_type = op_type.lower()
if op_type not in global_op_namespace:
global_op_namespace[op_type] = START_IDX
suffix = ""
else:
global_op_namespace[op_type] += 1
suffix = f"{global_op_namespace[op_type] - 1}"

new_name = f"{self._get_name(op_type)}{suffix}"
def _gen(t):
t = t.lower()
if t not in global_op_namespace:
global_op_namespace[t] = START_IDX
suffix = ""
else:
global_op_namespace[t] += 1
suffix = f"{global_op_namespace[t] - 1}"

return f"{self._get_name(t)}{suffix}"

new_name = _gen(op_type)
while new_name in global_var_namespace:
new_name = _gen(op_type)

global_var_namespace.add(new_name)
return new_name

+ 6
- 3
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -18,6 +18,7 @@ import importlib
import json
import os
from typing import Dict
from mindinsight.mindconverter.common.log import logger as log

CONFIG_JSON = "onnx_to_ms.json"
OPERATION_TABLE = os.path.join(
@@ -91,7 +92,8 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
weights_converter = getattr(converter, GET_OP_WEIGHTS)
except (ModuleNotFoundError,) as e:
# If mapper can not be found, then skip it.
print(f"Converting {op_name} failed, see {e}")
err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg)
return None, dict()

try:
@@ -99,8 +101,9 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
converted_params = params_converter(params, weights)
converted_weights = weights_converter(weights) if weights else dict()
converted_params.update(converted_weights)
except (AttributeError, KeyError, ValueError, TypeError) as _:
print(f"Converting {op_name} failed.")
except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg)
return None, dict()

return converter_name, converted_params


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -39,7 +39,7 @@ class ConvMapper(ONNXToMindSporeMapper):
else:
stride = params['strides']
kernel_shape = list(weight.shape)
in_channels = kernel_shape[-2]
in_channels = kernel_shape[-2] * params.get("group", 1)
out_channels = kernel_shape[-1]
kernel_size = kernel_shape[:-2]
if len(kernel_size) == 1:


+ 1
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py View File

@@ -31,8 +31,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_params(params, weights):
dim = 1 if len(params['input_shape']) == 3\
else 2
dim = 1 if len(params['input_shape']) == 3 else 2
if dim == 1:
kernel_size = params['input_shape'][-1] // params['output_shape'][-1]
else:


+ 14
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -100,6 +100,18 @@ class Graph(BaseGraph, abc.ABC):
self._topological_order = []
self._input_shape = dict()

def get_input_shape(self, name):
"""
Get node input shape.

Args:
name (str): Node name.

Returns:
list, shape.
"""
return self._input_shape.get(name)

def get_output_shape(self, name):
"""
Get node output shape.
@@ -112,7 +124,7 @@ class Graph(BaseGraph, abc.ABC):
"""
return self._shape_dict.get(name)

def get_input_shape(self, name):
def get_input_shape_from_input(self, name):
"""
Get node input shape.

@@ -482,7 +494,7 @@ class GraphNode(abc.ABC):
"""Return op_name."""

@abc.abstractmethod
def replace_with_arg(self, arg):
def replace_with_arg(self, src_arg, tgt_arg):
"""Replace actual parameter with formal parameter."""

@abc.abstractmethod


+ 23
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py View File

@@ -53,7 +53,7 @@ class InputNode(GraphNode):
def hash_key(self):
pass

def replace_with_arg(self, arg):
def replace_with_arg(self, src_arg, tgt_arg):
pass

def _get_arg_name(self, arg):
@@ -65,9 +65,30 @@ class InputNode(GraphNode):
def __init__(self, input_shape):
super(InputNode, self).__init__(node=None)
self._op_name = 'Input'
self._op_params = {'node_shape': input_shape}
self._op_params = {'input_shape': input_shape,
"output_shape": input_shape}
self._node_type = NodeType.INPUT.value

@property
def input_shape(self):
"""
Input tensor shape of current node.

Returns:
tuple, tensor shape of input.
"""
return self._op_params["input_shape"]

@property
def output_shape(self):
"""
Output tensor shape.

Returns:
tuple, output tensor shape.
"""
return self._op_params["output_shape"]

def set_scope_name(self, original_input_scope_name):
"""
Set scope name.


+ 76
- 44
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph."""
import warnings
import re
from typing import Dict, NoReturn

@@ -27,8 +26,11 @@ from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
from ..constant import LEFT_BUCKET, RIGHT_BUCKET

NONE_SCOPE_OP = {
'onnx::Add': 'Add',
'onnx::Flatten': 'Flatten',
"onnx::Add": "Add",
"onnx::Flatten": "Flatten",
"onnx::Concat": "Concat",
"onnx::Squeeze": "Squeeze",
"onnx::Unsqueeze": "Unsqueeze",
}


@@ -59,6 +61,7 @@ def normalize_scope_name(node):
scopes.append(segment)
if node.kind() in NONE_SCOPE_OP.keys():
scopes.append(NONE_SCOPE_OP[node.kind()])
scopes = [s for s in scopes if s]
return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}"


@@ -90,18 +93,16 @@ class PyTorchGraph(Graph):

"""
if not input_shape:
error = ValueError("`input_shape` can not be None.")
log.error(str(error))
log.exception(error)
raise error
err_msg = "`input_shape` can not be None."
log.error(err_msg)
raise ValueError(err_msg)

for item in input_shape:
if not isinstance(item, int):
error = ValueError(f"Only support model with one input now, "
f"and each shape value in `input_shape` should be int.")
log.error(str(error))
log.exception(error)
raise error
err_msg = f"Only support model with one input now, " \
f"and each shape value in `input_shape` should be int."
log.error(err_msg)
raise ValueError(err_msg)

@staticmethod
def _extract_shape(shape):
@@ -116,18 +117,29 @@ class PyTorchGraph(Graph):
"""
if "," not in shape:
return []

shape_arr = []
for s in shape.split(","):
s = s.strip()
if not s:
return []
return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]
if ":" in s:
s = s.split(":")[0]
s = s.replace("!", "")
if not s.isdigit():
return []
shape_arr.append(int(s))
return shape_arr

def build(self, input_shape):
def _trace_torch_graph(self, input_shape):
"""
Build graph tree.
Trace torch computational graph.

Args:
input_shape (tuple): Input shape of model.
input_shape (tuple): Shape.

Returns:
object, pytorch graph.
"""
import torch
from torch.onnx import OperatorExportTypes
@@ -135,24 +147,34 @@ class PyTorchGraph(Graph):
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer

self._check_input_shape(input_shape)

feed_forward_ipt_shape = (1, *input_shape)
batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape))

# Assign execution mode to eval.
self.model.eval()
batched_sample = create_autograd_variable(torch.rand(*input_shape))

try:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph
except RuntimeError as error:
log.error(str(error))
log.exception(error)
raise error

def build(self, input_shape):
"""
Build graph tree.

Args:
input_shape (tuple): Input shape of model.

"""
self._check_input_shape(input_shape)

feed_forward_ipt_shape = (1, *input_shape)
graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes())

for node in nodes:
@@ -174,24 +196,43 @@ class PyTorchGraph(Graph):

for node_input in list(node.inputs()):
# Connect input node and src node.
if PyTorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName():
nd_id = PyTorchGraph.get_node_id(node_input.node())
nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \
node_input.node().scopeName()

if nd_id and nd_scope_name:
node_input_name = normalize_scope_name(
node_input.node()
)
self.build_connection(node_input_name, node_name)

super(PyTorchGraph, self).build(input_shape=input_shape)
self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)

# Add Input Node
def _collect_ipt_shape_of_each_node(self, input_shape):
"""
Collect input tensor shape of each node.

Args:
input_shape (tuple): Input shape.

"""
input_node = InputNode(input_shape)
input_node_name = "{}InputNode"
for node_name, node in self._nodes_collection.items():
if node_name in self._input_nodes:
ipt_nd_name = input_node_name.format(input_node.scope_name)
input_node.set_scope_name(node.scope_name)
node.precursor_nodes.append(input_node.scope_name)
node.precursor_nodes.insert(0, ipt_nd_name)
input_node.set_successor_nodes(node_name)
self._nodes_collection[input_node.scope_name] = input_node
self._input_shape[node_name] = feed_forward_ipt_shape
break
self._shape_dict[ipt_nd_name] = input_node.output_shape

ipt_shape = []
for p_nd in node.precursor_nodes:
shp = self._shape_dict.get(p_nd)
ipt_shape.append(tuple(shp))

self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape

def sub_graph_merging(self):
"""
@@ -199,12 +240,6 @@ class PyTorchGraph(Graph):
"""
raise NotImplementedError()

def to_ir(self, mapper):
"""
Convert graph to IR graph.
"""
raise NotImplementedError()

def build_connection(self, src, tgt) -> NoReturn:
"""
Build connection between source node and target node.
@@ -215,13 +250,11 @@ class PyTorchGraph(Graph):

"""
# If src and tgt are the same node, src not in node_collection or
# tgt not in node_collection,
# then skip this edge.
# tgt not in node_collection, then skip this edge.
if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
if src.split(':')[0] not in self._nodes_collection:
warnings.warn(f"Graph construct a self-loop node {src}. Ignored.")
log.warning("Graph construct a self-loop node %s. Ignored.", src)
return

if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt)
if src not in self._nodes_collection[tgt].precursor_nodes:
@@ -244,11 +277,10 @@ class PyTorchGraph(Graph):
"""
Load graph metadata.
"""
error = NotImplementedError("class `PyTorchGraph` has not implemented "
"`load_metadata()`.")
log.error(str(error))
log.exception(error)
raise error
err_msg = "class `PyTorchGraph` has not implemented " \
"`load_metadata()`."
log.error(err_msg)
raise NotImplementedError(err_msg)

@staticmethod
def load_graph(graph_path: str):


+ 22
- 6
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph node."""
from copy import deepcopy

from .base import GraphNode

from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
@@ -140,7 +142,7 @@ class PyTorchGraphNode(GraphNode):
Returns:
str, op name.
"""
return self._op_name # if self.is_empty() else self.tag
return self._op_name

@property
def real_name(self):
@@ -177,8 +179,14 @@ class PyTorchGraphNode(GraphNode):
args.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
for k, v in args.items()])
if self._node_type == NodeType.OPERATION.value:
expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
for k, v in args.items()])
else:
# When it's type is module, class or func,
# it's not necessary to replace var.
expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
for k, v in args.items()])
declare = f"self.{self._variable_name} = {operator}({expr})"
call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})"

@@ -211,15 +219,16 @@ class PyTorchGraphNode(GraphNode):
raw_params[k] = getitem_of_node(node, k)
return raw_params

def replace_with_arg(self, arg):
def replace_with_arg(self, src_arg, tgt_arg):
"""
Replace actual parameter with formal parameter.

Args:
arg (str): Arg name.
src_arg (str): Original arg name.
tgt_arg (str): Target arg name.

"""
self._args_in_code[arg] = arg
self._args_in_code[src_arg] = tgt_arg

@staticmethod
def _extract_var_name(scope_name: str):
@@ -241,6 +250,13 @@ class PyTorchGraphNode(GraphNode):
mapper (Mapper): Mapper of params.

"""
if self._node_type != NodeType.OPERATION.value:
args = deepcopy(self._args_in_code)
self._args_in_code = dict()
for arg, value in args.items():
self._args_in_code[self._get_arg_name(arg)] = value
return None, None

if not self.transformed:
_, _ = super(PyTorchGraphNode, self).param_transform(mapper)



Loading…
Cancel
Save