Browse Source

!972 Optimize input_shape check & Add function to process multi-input and multi-output in pytorch.

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
0302369b3b
4 changed files with 84 additions and 17 deletions
  1. +2
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  2. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  3. +77
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  4. +4
    -14
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

+ 2
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -50,6 +50,8 @@ ARGUMENT_LENGTH_LIMIT = 512


EXPECTED_NUMBER = 1 EXPECTED_NUMBER = 1


MIN_SCOPE_LENGTH = 2



@unique @unique
class CodeFormatConfig(Enum): class CodeFormatConfig(Enum):


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -520,7 +520,7 @@ class GraphNode(abc.ABC):
if input_type == InputType.TENSOR.value: if input_type == InputType.TENSOR.value:
ipt_args_settings_in_construct = ipt_args_in_construct ipt_args_settings_in_construct = ipt_args_in_construct
elif input_type == InputType.LIST.value: elif input_type == InputType.LIST.value:
ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
ipt_args_settings_in_construct = f"({ipt_args_in_construct},)"
else: else:
raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
else: else:


+ 77
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Define PyTorch graph.""" """Define PyTorch graph."""
import re import re
from copy import deepcopy
from typing import Dict, NoReturn from typing import Dict, NoReturn


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
@@ -22,7 +23,8 @@ from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode from .pytorch_graph_node import PyTorchGraphNode
from .pytorch_graph_parser import PyTorchGraphParser from .pytorch_graph_parser import PyTorchGraphParser


from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID
from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
MIN_SCOPE_LENGTH
from ..constant import LEFT_BUCKET, RIGHT_BUCKET from ..constant import LEFT_BUCKET, RIGHT_BUCKET


NONE_SCOPE_OP = { NONE_SCOPE_OP = {
@@ -206,6 +208,8 @@ class PyTorchGraph(Graph):
) )
self.build_connection(node_input_name, node_name) self.build_connection(node_input_name, node_name)


self._unmerge_multi_ipt_opt_script()

super(PyTorchGraph, self).build(input_shape=input_shape) super(PyTorchGraph, self).build(input_shape=input_shape)
self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)


@@ -227,13 +231,84 @@ class PyTorchGraph(Graph):
input_node.set_successor_nodes(node_name) input_node.set_successor_nodes(node_name)
self._shape_dict[ipt_nd_name] = input_node.output_shape self._shape_dict[ipt_nd_name] = input_node.output_shape


if not self._shape_dict[node_name]:
self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE

ipt_shape = [] ipt_shape = []
for p_nd in node.precursor_nodes: for p_nd in node.precursor_nodes:
shp = self._shape_dict.get(p_nd) shp = self._shape_dict.get(p_nd)
ipt_shape.append(tuple(shp))
ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp)


self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape


def _generate_module(self):
"""Generate modules."""
module_dict = dict()
for node_key, _ in self._nodes_collection.items():
node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE)
if len(node_key_in_scope) < MIN_SCOPE_LENGTH:
continue

for idx in range(1, len(node_key_in_scope)):
node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx])
node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx+1])
if not module_dict.get(node_key_module, None):
module_dict[node_key_module] = {node_name}
else:
module_dict[node_key_module].add(node_name)

return module_dict

def _check_multi_ipt(self):
"""Check whether multi-input exists."""
module_dict = self._generate_module()
for _, nodes_per_module in module_dict.items():
prcs_nodes_out_from_module = set()
for node_name in nodes_per_module:
node = self._nodes_collection.get(node_name, None)
if node:
prcs_nodes = node.precursor_nodes
else:
continue

for prcs_node in prcs_nodes:
if prcs_node not in nodes_per_module:
prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1])
if prcs_node_module not in nodes_per_module:
prcs_nodes_out_from_module.add(prcs_node)

if len(prcs_nodes_out_from_module) > 1:
return True

return False

def _unmerge_multi_ipt_opt_script(self):
"""Unmerge all submodule."""
if self._check_multi_ipt():
for node_key, node_inst in deepcopy(self._nodes_collection).items():
prsc_nodes = node_inst.precursor_nodes
scsr_nodes = node_inst.successor_nodes

node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0],
prsc_node.split(SEPARATOR_IN_SCOPE)[-1]))
for prsc_node in deepcopy(prsc_nodes)]
node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0],
scsr_node.split(SEPARATOR_IN_SCOPE)[-1]))
for scsr_node in deepcopy(scsr_nodes)]

reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
node_key.split(SEPARATOR_IN_SCOPE)[-1]))

del self._nodes_collection[node_key]
self._nodes_collection[reduce_node_key] = node_inst

for node_key, shape in deepcopy(self._shape_dict).items():
reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
node_key.split(SEPARATOR_IN_SCOPE)[-1]))

del self._shape_dict[node_key]
self._shape_dict[reduce_node_key] = shape

def sub_graph_merging(self): def sub_graph_merging(self):
""" """
Merge split operation into one. Merge split operation into one.


+ 4
- 14
mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -48,18 +48,8 @@ class TFGraphParser(GraphParser):
log.error(str(error)) log.error(str(error))
raise error raise error


try:
model = convert_tf_graph_to_onnx(model_path,
model_inputs=tf_input_nodes,
model_outputs=tf_output_nodes,
) # need pass more args

except ModuleNotFoundError:
error_msg = \
"Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg)
log.error(str(error))
raise error from None

model = convert_tf_graph_to_onnx(model_path,
model_inputs=tf_input_nodes,
model_outputs=tf_output_nodes,
)
return model return model

Loading…
Cancel
Save