Browse Source

!1054 Fix converter for UNet in pytorch.

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
bb8d69f222
6 changed files with 298 additions and 22 deletions
  1. +2
    -1
      mindinsight/mindconverter/README.md
  2. +2
    -1
      mindinsight/mindconverter/README_CN.md
  3. +6
    -1
      mindinsight/mindconverter/graph_based_converter/constant.py
  4. +24
    -6
      mindinsight/mindconverter/graph_based_converter/framework.py
  5. +261
    -13
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  6. +3
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py

+ 2
- 1
mindinsight/mindconverter/README.md View File

@@ -155,7 +155,8 @@ Supported models list (Models in below table have been tested based on PyTorch 1
| DenseNet121/169/201 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | |
| DenseNet161 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | / | |
| NASNetMobile/Large | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | |
| EfficientNetB0~B7 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | Use TensorFlow 2.3 to export model and convert |
| EfficientNetB0~B7 | [Link](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.5Link](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) [TF2.3Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | |
| Unet | [Link](https://github.com/milesial/Pytorch-UNet) | [Link](https://github.com/zhixuhao/unet) | Due to Operator `ResizeBilinear` not achieved on GPU device, Operator `ResizeBilinear` should be replaced by operator `ResizeNearest`, while running in GPU device |

## Example



+ 2
- 1
mindinsight/mindconverter/README_CN.md View File

@@ -154,7 +154,8 @@ MindConverter提供两种技术方案,以应对不同脚本迁移场景:
| DenseNet121/169/201 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | |
| DenseNet161 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | 暂未测试 | |
| NASNetMobile/Large | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | |
| EfficientNetB0~B7 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | 模型使用TensorFlow 2.3导出、转换 |
| EfficientNetB0~B7 | [脚本链接](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.5脚本链接](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) [TF2.3脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | |
| Unet | [脚本链接](https://github.com/milesial/Pytorch-UNet) | [脚本链接](https://github.com/zhixuhao/unet) | 由于算子`ResizeBilinear`在GPU上未实现,所以当运行在GPU设备上时,算子`ResizeBilinear`需要被替换为算子`ResizeNearest` |

## 使用示例



+ 6
- 1
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -53,8 +53,13 @@ EXPECTED_NUMBER = 1

MIN_SCOPE_LENGTH = 2

ONNX_OPSET_VERSION = 11

MODEL_INPUT_NAME = 'input.1'

NO_CONVERTED_OPERATORS = [
"onnx::Constant"
"onnx::Constant",
"Constant"
]




+ 24
- 6
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -14,7 +14,6 @@
# ==============================================================================
"""Graph based scripts converter workflow."""
import os
import re
import argparse
import sys
from importlib import import_module
@@ -65,10 +64,28 @@ def torch_installation_validation(func):
def _f(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None):
# Check whether pytorch is installed.
if not find_spec("torch"):
error = RuntimeIntegrityError("PyTorch is required when using graph based "
"scripts converter, and PyTorch version must "
"be consisted with model generation runtime.")
if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) "
f"are required when using graph based "
f"scripts converter, and PyTorch version must "
f"be consisted with model generation runtime.")
log.error(error)
log_console.error("\n")
log_console.error(str(error))
log_console.error("\n")
sys.exit(0)

onnx = import_module("onnx")
ort = import_module("onnxruntime")

if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER):
error = RuntimeIntegrityError(
f"onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"based scripts converter for Pytorch conversion."
)
log.error(error)
log_console.error("\n")
log_console.error(str(error))
@@ -154,7 +171,8 @@ def _extract_model_name(model_path):
str: Name of Converted model.
"""

model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1]
base_path = os.path.basename(model_path)
model_name = '.'.join(base_path.split('.')[:-1])
return model_name




+ 261
- 13
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -13,19 +13,28 @@
# limitations under the License.
# ==============================================================================
"""Define PyTorch graph."""
import os
import re
import warnings
from copy import deepcopy
from importlib import import_module
from typing import Dict, NoReturn

import numpy as np

from mindinsight.conf import settings
from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
from .pytorch_graph_parser import PyTorchGraphParser
from .torch_utils import set_opset_version
from ..common.utils import fetch_output_from_onnx_model

from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \
MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT
MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME
from ..constant import LEFT_BUCKET, RIGHT_BUCKET
from ...common.exceptions import ModelNotSupportError

NONE_SCOPE_OP = {
"onnx::Add": "Add",
@@ -37,7 +46,39 @@ NONE_SCOPE_OP = {
"onnx::Reshape": "Reshape",
"onnx::Transpose": "Transpose",
"onnx::Constant": "Constant",
"onnx::ReduceMean": "ReduceMean"
"onnx::ReduceMean": "ReduceMean",
"onnx::Resize": "Resize",
"onnx::Pad": "Pad"
}

CONSTANT_NODES_PATTERN = {
"onnx::Resize": [
'onnx::Concat',
'onnx::Slice',
'onnx::Cast',
'onnx::Concat',
'onnx::Unsqueeze',
'onnx::Floor',
'onnx::Mul',
'onnx::Cast',
'onnx::Gather',
'onnx::Shape'
],
"onnx::Pad": [
'onnx::Cast',
'onnx::Concat',
'onnx::ConstantOfShape',
'onnx::Sub',
'onnx::Mul',
'onnx::Div',
'onnx::Gather',
'onnx::Shape',
'onnx::Unsqueeze',
'onnx::Slice',
'onnx::Reshape',
'onnx::Transpose'
],
"onnx::Constant": list()
}


@@ -129,6 +170,15 @@ class PyTorchGraph(Graph):
from .torch_utils import unique_state_dict

self._params_dict = unique_state_dict(model)
self._original_shape = list()
self._nodes = list()
self._constant_nodes = list()
self._dynamic_nodes = list()
self._has_eliminated_nodes = False
self._file_graph_onnx = os.path.join(
settings.WORKSPACE, 'log/mindconverter/'
)

self.build(sample_shape)

@staticmethod
@@ -195,17 +245,32 @@ class PyTorchGraph(Graph):
from .torch_utils import create_autograd_variable
from .torch_utils import onnx_tracer

warnings.simplefilter("ignore")

batched_sample = create_autograd_variable(torch.rand(*input_shape))

try:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph
try:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph

except RuntimeError:
# Assign execution mode to eval.
self.model.eval()

with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
set_opset_version(ONNX_OPSET_VERSION)
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
return graph

except RuntimeError as error:
log.error(str(error))
log.exception(error)
@@ -220,14 +285,21 @@ class PyTorchGraph(Graph):

"""
self._check_input_shape(input_shape)
self._original_shape = input_shape
feed_forward_ipt_shape = tuple(input_shape)
graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes())
self._nodes = nodes

scope_name_dict = dict()

self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes)

for node in nodes:
output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs()))
if output_name in self._dynamic_nodes:
continue

node_name = normalize_scope_name(node, scope_name_dict)
scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \
= list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE))
@@ -237,16 +309,33 @@ class PyTorchGraph(Graph):
weight_scope = '.'.join(
re.findall(r'\[([\w\d.]+)]', node.scopeName())
)
node_weight = {}

if self._constant_nodes:
node_weight = self._replace_constant_node(node)
else:
node_weight = {}

for scope, weight in self._params_dict.items():
split_scope = scope.split('.')
if '.'.join(split_scope[:-1]) == weight_scope:
node_weight[split_scope[-1]] = weight

if not node_weight and node.kind() == 'onnx::Conv':
weight_names = list(self._params_dict.keys())
node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_input_names:
if int(node_input_name) > len(weight_names):
continue
weight = self._params_dict[weight_names[int(node_input_name) - 1]]
node_weight[weight_names[int(node_input_name) - 1]] = weight

self._shape_dict[node_name] = output_shape
self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight)
self._nodes_record[node_name] = node_name

for node_input in list(node.inputs()):
if self._extract_input_name(node_input) in self._constant_nodes:
continue
# Connect input node and src node.
nd_id = PyTorchGraph.get_node_id(node_input.node())
nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \
@@ -263,6 +352,165 @@ class PyTorchGraph(Graph):
super(PyTorchGraph, self).build(input_shape=input_shape)
self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)

@staticmethod
def _extract_node_name(node):
"""Extract node name for node."""
result = re.match(r"\d+", str(node))
if result:
return result.group(0)
return None

@staticmethod
def _extract_input_name(node_input):
"""Extract node input name from node input."""
node_input_name = str(node_input).split('defined in')[0].strip()
return node_input_name

def _get_constant_nodes(self, nodes):
"""
Get constant nodes to be eliminated.

Args:
nodes (Nodes): Nodes in torch._C.Graph.

Returns:
Union(dict, list), output of constant_input_node_name and dynamic nodes name.
"""
constant_input_nodes = list()
dynamic_nodes = list()
for node in nodes:
if node.kind() == 'onnx::Resize':
self._has_eliminated_nodes = True
constant_input_node, dynamic_node = self._generate_inputs_of(node)
constant_input_nodes += constant_input_node
dynamic_nodes += dynamic_node

outputs = dict()
if self._has_eliminated_nodes:
torch = import_module('torch')
device_target = 'cuda' if torch.cuda.is_available() else 'cpu'
dump_input = torch.randn(*self._original_shape, device=device_target)
temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx,
'.graph_onnx.onnx'))

symbolic_helper = import_module('torch.onnx.symbolic_helper')
export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version')
try:
torch.onnx.export(self.model.to(device_target), dump_input,
temp_onnx_path, opset_version=export_onnx_opset_version)

outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape)
finally:
if os.path.exists(temp_onnx_path):
os.remove(temp_onnx_path)

return outputs, dynamic_nodes

def _generate_inputs_of(self, node):
"""
Generate inputs of certain node.

Args:
node (Node): Node of torch._C.Graph.

"""
pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None)
constant_input_nodes = list()
dynamic_nodes = list()
if not isinstance(pattern_op_lst, list):
return constant_input_nodes, dynamic_nodes
if not pattern_op_lst:
dynamic_nodes += self.get_node_id(node)
return constant_input_nodes, dynamic_nodes

node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]

for node_input_name in node_inputs_name:
node_name_path = self._search_node_path(node_input_name, pattern_op_lst)
if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape':
constant_input_nodes.append(node_input_name)
dynamic_nodes += node_name_path

return constant_input_nodes, dynamic_nodes

def _search_node_path(self, node_name, pattern_op_lst):
"""
Search node path based on pattern_op_list.

Args:
node_name (str): Node name.
pattern_op_lst (list): Pattern list of certain operator.

Returns:
list[str]: node names in pattern.
"""
node_type_lst = list()
node_name_lst = list()
node = self._get_node_from_graph(node_name)

if node_name == MODEL_INPUT_NAME:
return node_name_lst

if node.kind() not in pattern_op_lst:
return node_name_lst

node_type_lst.append(node.kind())
node_name_lst.append(node_name)

node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()]
for node_input_name in node_inputs_name:
node_name_lst += self._search_node_path(node_input_name, pattern_op_lst)

return node_name_lst

def _get_node_from_graph(self, node_name):
"""Get torch._C.Node from torch._C.Graph."""
for idx, node in enumerate(self._nodes):
node_id = ', '.join(self.get_node_id(node))
if node_id == node_name:
return self._nodes[idx]
return None

@staticmethod
def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape):
"""
Infer onnx model to get outputs of inner nodes.

Args:
file_graph_onnx (str): File path of onnx.
infer_outputs (list): Outputs for infer.
infer_inputs_shape (list): Input shape for infer.

"""
onnx = import_module('onnx')
tensor_proto = getattr(onnx, 'TensorProto')
onnx_model = onnx.load(file_graph_onnx)

for onnx_node in onnx_model.graph.node:
if set(onnx_node.output).issubset(set(infer_outputs)):
onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output])

input_onnx = onnx_model.graph.input[0]
node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type)
if node_type != 'FLOAT':
raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. "
f"Please report issue to us if extra input type is needed.")

input_onnx_name = input_onnx.name
feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)}
outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs)

return outputs

def _replace_constant_node(self, node):
"""Replace constant node."""
node_weight = dict()
for node_input in list(node.inputs()):
node_input_name = self._extract_input_name(node_input)
if node_input_name in self._constant_nodes:
node_weight[node_input_name] = self._constant_nodes[node_input_name]
return node_weight

def _collect_ipt_shape_of_each_node(self, input_shape):
"""
Collect input tensor shape of each node.
@@ -338,7 +586,7 @@ class PyTorchGraph(Graph):

def _unmerge_multi_ipt_opt_script(self):
"""Unmerge all submodule."""
if self._check_multi_ipt_opt():
if self._check_multi_ipt_opt() or self._has_eliminated_nodes:
for node_key, node_inst in deepcopy(self._nodes_collection).items():
prsc_nodes = node_inst.precursor_nodes
scsr_nodes = node_inst.successor_nodes


+ 3
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py View File

@@ -18,11 +18,14 @@ import importlib
from torch.nn import Module
from torch.onnx.utils import _trace
from torch.onnx.utils import _node_getitem
from torch.onnx.symbolic_helper import _set_opset_version


SCRIPT_METHOD = getattr(importlib.import_module("torch._C"),
"ScriptMethod")
onnx_tracer = _trace
getitem_of_node = _node_getitem
set_opset_version = _set_opset_version


def unique_state_dict(model):


Loading…
Cancel
Save