@@ -13,19 +13,28 @@
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph."""
import os
import re
import warnings
from copy import deepcopy
from importlib import import_module
from typing import Dict, NoReturn
import numpy as np
from mindinsight.conf import settings
from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
from .pytorch_graph_parser import PyTorchGraphParser
from .torch_utils import set_opset_version
from ..common.utils import fetch_output_from_onnx_model
from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME
from ..constant import LEFT_BUCKET, RIGHT_BUCKET
from ...common.exceptions import ModelNotSupportError
NONE_SCOPE_OP = {
"onnx::Add": "Add",
@@ -37,7 +46,39 @@ NONE_SCOPE_OP = {
"onnx::Reshape": "Reshape",
"onnx::Transpose": "Transpose",
"onnx::Constant": "Constant",
"onnx::ReduceMean": "ReduceMean"
"onnx::ReduceMean": "ReduceMean",
"onnx::Resize": "Resize",
"onnx::Pad": "Pad"
}
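# Operators that typically carry no scopeName() in the traced graph; the mapped
# name above is used for them when scope-based node names are composed.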
CONSTANT_NODES_PATTERN = {
"onnx::Resize": [
'onnx::Concat',
'onnx::Slice',
'onnx::Cast',
'onnx::Concat',
'onnx::Unsqueeze',
'onnx::Floor',
'onnx::Mul',
'onnx::Cast',
'onnx::Gather',
'onnx::Shape'
],
"onnx::Pad": [
'onnx::Cast',
'onnx::Concat',
'onnx::ConstantOfShape',
'onnx::Sub',
'onnx::Mul',
'onnx::Div',
'onnx::Gather',
'onnx::Shape',
'onnx::Unsqueeze',
'onnx::Slice',
'onnx::Reshape',
'onnx::Transpose'
],
"onnx::Constant": list()
}
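# For each operator above, CONSTANT_NODES_PATTERN lists the kinds of upstream nodes
# that may form a shape-derived constant subgraph feeding it. _search_node_path walks
# an input's producers and keeps only chains made of these kinds; a chain counts as
# constant-foldable when it bottoms out at an onnx::Shape node. The empty list for
# onnx::Constant marks such a node itself as dynamic, i.e. to be skipped.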
@@ -129,6 +170,15 @@ class PyTorchGraph(Graph):
from .torch_utils import unique_state_dict
self._params_dict = unique_state_dict(model)
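        # Bookkeeping for eliminating shape-derived constant subgraphs (e.g. the
        # computed inputs of onnx::Resize / onnx::Pad): traced nodes, inferred
        # constant values, names of nodes to drop, a flag recording whether any
        # elimination happened, and the directory used for a temporary ONNX dump.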
self._original_shape = list()
self._nodes = list()
self._constant_nodes = list()
self._dynamic_nodes = list()
self._has_eliminated_nodes = False
self._file_graph_onnx = os.path.join(
settings.WORKSPACE, 'log/mindconverter/'
)
self.build(sample_shape)
@staticmethod
@@ -195,17 +245,32 @@ class PyTorchGraph(Graph):
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer
warnings.simplefilter("ignore")
batched_sample = create_autograd_variable(torch.rand(*input_shape))
try:
try:
# Assign execution mode to eval.
self.model.eval()
with OverloadTorchModuleTemporarily() as _:
                    # In higher PyTorch versions, the trace function has a known issue.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph
except RuntimeError:
# Assign execution mode to eval.
self.model.eval()
with OverloadTorchModuleTemporarily() as _:
                    # In higher PyTorch versions, the trace function has a known issue;
                    # retry with the ONNX opset version pinned to ONNX_OPSET_VERSION.
set_opset_version(ONNX_OPSET_VERSION)
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph
except RuntimeError as error:
log.error(str(error))
log.exception(error)
@@ -220,14 +285,21 @@ class PyTorchGraph(Graph):
"""
self._check_input_shape(input_shape)
self._original_shape = input_shape
feed_forward_ipt_shape = tuple(input_shape)
graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes())
self._nodes = nodes
scope_name_dict = dict()
self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes)
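        # Nodes identified as dynamic are skipped below; constant inputs are later
        # substituted by their inferred values in _replace_constant_node.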
for node in nodes:
output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs()))
if output_name in self._dynamic_nodes:
continue
node_name = normalize_scope_name(node, scope_name_dict)
scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \
= list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE))
@@ -237,16 +309,33 @@ class PyTorchGraph(Graph):
weight_scope = '.'.join(
re.findall(r'\[([\w\d.]+)]', node.scopeName())
)
if self._constant_nodes:
node_weight = self._replace_constant_node(node)
else:
node_weight = {}
for scope, weight in self._params_dict.items():
split_scope = scope.split('.')
if '.'.join(split_scope[:-1]) == weight_scope:
node_weight[split_scope[-1]] = weight
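            # Fallback for onnx::Conv nodes whose scope-based lookup found no weights:
            # map the node's numeric input ids onto the ordered state-dict keys
            # (input id N -> N-th parameter), a positional heuristic that assumes the
            # traced ids follow the parameter order.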
if not node_weight and node.kind() == 'onnx::Conv':
weight_names = list(self._params_dict.keys())
node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_input_names:
if int(node_input_name) > len(weight_names):
continue
weight = self._params_dict[weight_names[int(node_input_name) - 1]]
node_weight[weight_names[int(node_input_name) - 1]] = weight
self._shape_dict[node_name] = output_shape
self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight)
self._nodes_record[node_name] = node_name
for node_input in list(node.inputs()):
if self._extract_input_name(node_input) in self._constant_nodes:
continue
# Connect input node and src node.
nd_id = PyTorchGraph.get_node_id(node_input.node())
nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \
@@ -263,6 +352,165 @@ class PyTorchGraph(Graph):
super(PyTorchGraph, self).build(input_shape=input_shape)
self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)
@staticmethod
def _extract_node_name(node):
"""Extract node name for node."""
result = re.match(r"\d+", str(node))
if result:
return result.group(0)
return None
@staticmethod
def _extract_input_name(node_input):
"""Extract node input name from node input."""
node_input_name = str(node_input).split('defined in')[0].strip()
return node_input_name
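    # Note: both helpers parse the textual representation of torch._C values/nodes,
    # which is version dependent; _extract_input_name keeps the part before
    # "defined in" (typically the numeric value id), and _extract_node_name keeps
    # the leading digits of str(node), if any.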
def _get_constant_nodes(self, nodes):
"""
Get constant nodes to be eliminated.
Args:
nodes (Nodes): Nodes in torch._C.Graph.
Returns:
            tuple[dict, list], inferred values of the constant input nodes and names of the dynamic nodes.
"""
constant_input_nodes = list()
dynamic_nodes = list()
for node in nodes:
if node.kind() == 'onnx::Resize':
self._has_eliminated_nodes = True
constant_input_node, dynamic_node = self._generate_inputs_of(node)
constant_input_nodes += constant_input_node
dynamic_nodes += dynamic_node
outputs = dict()
if self._has_eliminated_nodes:
torch = import_module('torch')
device_target = 'cuda' if torch.cuda.is_available() else 'cpu'
dump_input = torch.randn(*self._original_shape, device=device_target)
temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx,
'.graph_onnx.onnx'))
symbolic_helper = import_module('torch.onnx.symbolic_helper')
export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version')
try:
torch.onnx.export(self.model.to(device_target), dump_input,
temp_onnx_path, opset_version=export_onnx_opset_version)
outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape)
finally:
if os.path.exists(temp_onnx_path):
os.remove(temp_onnx_path)
return outputs, dynamic_nodes
def _generate_inputs_of(self, node):
"""
        Generate the constant input names and dynamic node names of a certain node.

        Args:
            node (Node): Node of torch._C.Graph.

        Returns:
            tuple[list, list], names of constant input nodes and names of dynamic nodes.
        """
pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None)
constant_input_nodes = list()
dynamic_nodes = list()
if not isinstance(pattern_op_lst, list):
return constant_input_nodes, dynamic_nodes
if not pattern_op_lst:
dynamic_nodes += self.get_node_id(node)
return constant_input_nodes, dynamic_nodes
node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_inputs_name:
node_name_path = self._search_node_path(node_input_name, pattern_op_lst)
if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape':
constant_input_nodes.append(node_input_name)
dynamic_nodes += node_name_path
return constant_input_nodes, dynamic_nodes
def _search_node_path(self, node_name, pattern_op_lst):
"""
        Search the node path based on pattern_op_lst.
Args:
node_name (str): Node name.
pattern_op_lst (list): Pattern list of certain operator.
Returns:
            list[str], node names matched along the pattern.
"""
node_type_lst = list()
node_name_lst = list()
node = self._get_node_from_graph(node_name)
if node_name == MODEL_INPUT_NAME:
return node_name_lst
if node.kind() not in pattern_op_lst:
return node_name_lst
node_type_lst.append(node.kind())
node_name_lst.append(node_name)
node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_inputs_name:
node_name_lst += self._search_node_path(node_input_name, pattern_op_lst)
return node_name_lst
def _get_node_from_graph(self, node_name):
"""Get torch._C.Node from torch._C.Graph."""
for idx, node in enumerate(self._nodes):
node_id = ', '.join(self.get_node_id(node))
if node_id == node_name:
return self._nodes[idx]
return None
@staticmethod
def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape):
"""
        Run inference on the ONNX model to get the outputs of inner nodes.

        Args:
            file_graph_onnx (str): File path of the ONNX model.
            infer_outputs (list): Names of the node outputs to fetch.
            infer_inputs_shape (list): Input shape used for inference.

        Returns:
            dict, inferred values of the requested outputs, keyed by output name.
        """
onnx = import_module('onnx')
tensor_proto = getattr(onnx, 'TensorProto')
onnx_model = onnx.load(file_graph_onnx)
for onnx_node in onnx_model.graph.node:
if set(onnx_node.output).issubset(set(infer_outputs)):
onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output])
input_onnx = onnx_model.graph.input[0]
node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type)
if node_type != 'FLOAT':
raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. "
f"Please report issue to us if extra input type is needed.")
input_onnx_name = input_onnx.name
feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)}
outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs)
return outputs
def _replace_constant_node(self, node):
"""Replace constant node."""
node_weight = dict()
for node_input in list(node.inputs()):
node_input_name = self._extract_input_name(node_input)
if node_input_name in self._constant_nodes:
node_weight[node_input_name] = self._constant_nodes[node_input_name]
return node_weight
def _collect_ipt_shape_of_each_node(self, input_shape):
"""
Collect input tensor shape of each node.
@@ -338,7 +586,7 @@ class PyTorchGraph(Graph):
def _unmerge_multi_ipt_opt_script(self):
"""Unmerge all submodule."""
        if self._check_multi_ipt_opt() or self._has_eliminated_nodes:
for node_key, node_inst in deepcopy(self._nodes_collection).items():
prsc_nodes = node_inst.precursor_nodes
scsr_nodes = node_inst.successor_nodes