
!972 Optimize input_shape check & Add function to process multi-input and multi-output in PyTorch.

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 0302369b3b
4 changed files with 84 additions and 17 deletions:
  1. +2  -0    mindinsight/mindconverter/graph_based_converter/constant.py
  2. +1  -1    mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  3. +77 -2    mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  4. +4  -14   mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

+2 -0  mindinsight/mindconverter/graph_based_converter/constant.py

@@ -50,6 +50,8 @@ ARGUMENT_LENGTH_LIMIT = 512
 
 EXPECTED_NUMBER = 1
 
+MIN_SCOPE_LENGTH = 2
+
 
 @unique
 class CodeFormatConfig(Enum):
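
A note on the new constant: MIN_SCOPE_LENGTH is the minimum number of scope levels a node key must have before _generate_module (added in pytorch_graph.py below) treats the node as living inside a submodule. A minimal sketch, assuming SEPARATOR_IN_SCOPE is the "/" scope separator and using a hypothetical node key:

    SEPARATOR_IN_SCOPE = "/"   # assumed value, defined elsewhere in constant.py
    MIN_SCOPE_LENGTH = 2

    node_key = "ResNet/Conv2d[conv1]"            # hypothetical node key
    scopes = node_key.split(SEPARATOR_IN_SCOPE)  # ['ResNet', 'Conv2d[conv1]']
    # Keys with fewer than MIN_SCOPE_LENGTH levels are top-level nodes and
    # are skipped when grouping nodes into modules.
    inside_module = len(scopes) >= MIN_SCOPE_LENGTH  # True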


+1 -1  mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -520,7 +520,7 @@ class GraphNode(abc.ABC):
             if input_type == InputType.TENSOR.value:
                 ipt_args_settings_in_construct = ipt_args_in_construct
             elif input_type == InputType.LIST.value:
-                ipt_args_settings_in_construct = f"({ipt_args_in_construct})"
+                ipt_args_settings_in_construct = f"({ipt_args_in_construct},)"
             else:
                 raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.")
         else:
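
Why the one-character change above matters: in Python, parentheses alone do not create a tuple; the trailing comma does. The generated construct code for a LIST-typed input must therefore emit "(x,)" rather than "(x)". A minimal illustration (standard Python semantics; the variable name is hypothetical):

    x = 1
    print((x))    # 1    -- parentheses are only grouping, this is just x
    print((x,))   # (1,) -- the comma makes a one-element tuple
    # Without the fix, a single LIST-typed input degenerated to a bare
    # value in the generated script instead of a one-element tuple.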


+77 -2  mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py

@@ -14,6 +14,7 @@
 # ==============================================================================
 """Define PyTorch graph."""
 import re
+from copy import deepcopy
 from typing import Dict, NoReturn
 
 from mindinsight.mindconverter.common.log import logger as log
@@ -22,7 +23,8 @@ from .input_node import InputNode
 from .pytorch_graph_node import PyTorchGraphNode
 from .pytorch_graph_parser import PyTorchGraphParser
 
-from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID
+from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
+    MIN_SCOPE_LENGTH
 from ..constant import LEFT_BUCKET, RIGHT_BUCKET
 
 NONE_SCOPE_OP = {
@@ -206,6 +208,8 @@ class PyTorchGraph(Graph):
             )
             self.build_connection(node_input_name, node_name)
 
+        self._unmerge_multi_ipt_opt_script()
+
         super(PyTorchGraph, self).build(input_shape=input_shape)
         self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)
 
@@ -227,13 +231,84 @@
                 input_node.set_successor_nodes(node_name)
                 self._shape_dict[ipt_nd_name] = input_node.output_shape
 
+            if not self._shape_dict[node_name]:
+                self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE
+
             ipt_shape = []
             for p_nd in node.precursor_nodes:
                 shp = self._shape_dict.get(p_nd)
-                ipt_shape.append(tuple(shp))
+                ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp)
 
             self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape
 
+    def _generate_module(self):
+        """Generate modules."""
+        module_dict = dict()
+        for node_key, _ in self._nodes_collection.items():
+            node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE)
+            if len(node_key_in_scope) < MIN_SCOPE_LENGTH:
+                continue
+
+            for idx in range(1, len(node_key_in_scope)):
+                node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx])
+                node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx + 1])
+                if not module_dict.get(node_key_module, None):
+                    module_dict[node_key_module] = {node_name}
+                else:
+                    module_dict[node_key_module].add(node_name)
+
+        return module_dict
+
+    def _check_multi_ipt(self):
+        """Check whether multi-input exists."""
+        module_dict = self._generate_module()
+        for _, nodes_per_module in module_dict.items():
+            prcs_nodes_out_from_module = set()
+            for node_name in nodes_per_module:
+                node = self._nodes_collection.get(node_name, None)
+                if node:
+                    prcs_nodes = node.precursor_nodes
+                else:
+                    continue
+
+                for prcs_node in prcs_nodes:
+                    if prcs_node not in nodes_per_module:
+                        prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1])
+                        if prcs_node_module not in nodes_per_module:
+                            prcs_nodes_out_from_module.add(prcs_node)
+
+            if len(prcs_nodes_out_from_module) > 1:
+                return True
+
+        return False
+
+    def _unmerge_multi_ipt_opt_script(self):
+        """Unmerge all submodule."""
+        if self._check_multi_ipt():
+            for node_key, node_inst in deepcopy(self._nodes_collection).items():
+                prsc_nodes = node_inst.precursor_nodes
+                scsr_nodes = node_inst.successor_nodes
+
+                node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0],
+                                                                      prsc_node.split(SEPARATOR_IN_SCOPE)[-1]))
+                                             for prsc_node in deepcopy(prsc_nodes)]
+                node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0],
+                                                                      scsr_node.split(SEPARATOR_IN_SCOPE)[-1]))
+                                             for scsr_node in deepcopy(scsr_nodes)]
+
+                reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
+                                                           node_key.split(SEPARATOR_IN_SCOPE)[-1]))
+
+                del self._nodes_collection[node_key]
+                self._nodes_collection[reduce_node_key] = node_inst
+
+            for node_key, shape in deepcopy(self._shape_dict).items():
+                reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0],
+                                                           node_key.split(SEPARATOR_IN_SCOPE)[-1]))
+
+                del self._shape_dict[node_key]
+                self._shape_dict[reduce_node_key] = shape
+
     def sub_graph_merging(self):
         """
         Merge split operation into one.
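
For readers following the new methods above, here is a self-contained sketch of the scope-prefix grouping that _generate_module performs and that _check_multi_ipt consumes. The node keys and the "/" separator are illustrative assumptions, not values taken verbatim from MindConverter:

    SEPARATOR_IN_SCOPE = "/"   # assumed scope separator
    MIN_SCOPE_LENGTH = 2

    node_keys = [
        "Net/Sequential[block]/Conv2d[conv]",   # hypothetical node keys
        "Net/Sequential[block]/ReLU[relu]",
        "Net/Linear[fc]",
    ]

    module_dict = {}
    for node_key in node_keys:
        parts = node_key.split(SEPARATOR_IN_SCOPE)
        if len(parts) < MIN_SCOPE_LENGTH:
            continue
        # Each scope prefix maps to the member names one level beneath it.
        for idx in range(1, len(parts)):
            module = SEPARATOR_IN_SCOPE.join(parts[:idx])
            member = SEPARATOR_IN_SCOPE.join(parts[:idx + 1])
            module_dict.setdefault(module, set()).add(member)

    # module_dict == {
    #     'Net': {'Net/Sequential[block]', 'Net/Linear[fc]'},
    #     'Net/Sequential[block]': {'Net/Sequential[block]/Conv2d[conv]',
    #                               'Net/Sequential[block]/ReLU[relu]'},
    # }

_check_multi_ipt then counts, per module, the distinct precursor nodes coming from outside that module; more than one such precursor means some submodule would need multiple inputs, which triggers the unmerge path.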

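When the multi-input case is detected, _unmerge_multi_ipt_opt_script flattens every node key so that intermediate submodule scopes disappear and no standalone submodule classes are generated. A sketch of that key reduction, again with a hypothetical key and an assumed "/" separator:

    SEPARATOR_IN_SCOPE = "/"   # assumed scope separator

    def reduce_key(node_key):
        """Keep only the outermost scope and the node itself."""
        parts = node_key.split(SEPARATOR_IN_SCOPE)
        return SEPARATOR_IN_SCOPE.join((parts[0], parts[-1]))

    print(reduce_key("ResNet/Sequential[layer1]/BasicBlock[0]/Conv2d[conv1]"))
    # -> ResNet/Conv2d[conv1]

The same reduction is applied to each node's precursor list, successor list, and the keys of _shape_dict, so all three stay consistent after flattening.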

+4 -14  mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

@@ -48,18 +48,8 @@ class TFGraphParser(GraphParser):
             log.error(str(error))
             raise error
 
-        try:
-            model = convert_tf_graph_to_onnx(model_path,
-                                             model_inputs=tf_input_nodes,
-                                             model_outputs=tf_output_nodes,
-                                             )  # need pass more args
-
-        except ModuleNotFoundError:
-            error_msg = \
-                "Cannot find model scripts in system path, " \
-                "set `--project_path` to the path of model scripts folder correctly."
-            error = ModuleNotFoundError(error_msg)
-            log.error(str(error))
-            raise error from None
-
+        model = convert_tf_graph_to_onnx(model_path,
+                                         model_inputs=tf_input_nodes,
+                                         model_outputs=tf_output_nodes,
+                                         )
         return model
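
The net effect of the tf_graph_parser.py change: a missing dependency inside convert_tf_graph_to_onnx now propagates as the original ModuleNotFoundError instead of being re-raised with a `--project_path` hint. A minimal sketch of the behavioral difference, with convert_tf_graph_to_onnx stubbed out (the real function and its arguments live in MindConverter, and the exact failing import is an assumption):

    def convert_tf_graph_to_onnx(model_path, model_inputs=None, model_outputs=None):
        # Stub: assume the real converter imports an optional dependency here,
        # which raises ModuleNotFoundError when it is absent.
        raise ModuleNotFoundError("No module named 'tf2onnx'")

    try:
        convert_tf_graph_to_onnx("model.pb")
    except ModuleNotFoundError as err:
        # After this commit, callers see the raw import failure; before,
        # it was re-wrapped with a hint about setting --project_path.
        print(err)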
