diff --git a/mindinsight/mindconverter/README.md b/mindinsight/mindconverter/README.md index 7ebb93b9..5273364b 100644 --- a/mindinsight/mindconverter/README.md +++ b/mindinsight/mindconverter/README.md @@ -155,7 +155,8 @@ Supported models list (Models in below table have been tested based on PyTorch 1 | DenseNet121/169/201 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | | | DenseNet161 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | / | | | NASNetMobile/Large | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | | -| EfficientNetB0~B7 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | Use TensorFlow 2.3 to export model and convert | +| EfficientNetB0~B7 | [Link](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.5Link](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) [TF2.3Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | | +| Unet | [Link](https://github.com/milesial/Pytorch-UNet) | [Link](https://github.com/zhixuhao/unet) | The operator `ResizeBilinear` is not implemented on GPU, so replace it with the operator `ResizeNearest` when running on a GPU device | ## Example diff --git a/mindinsight/mindconverter/README_CN.md b/mindinsight/mindconverter/README_CN.md index 0f8821fa..265a7633 100644 --- a/mindinsight/mindconverter/README_CN.md +++ b/mindinsight/mindconverter/README_CN.md @@ -154,7 +154,8 @@ MindConverter提供两种技术方案,以应对不同脚本迁移场景: | DenseNet121/169/201 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | 
[脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | | | DenseNet161 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | 暂未测试 | | | NASNetMobile/Large | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | | -| EfficientNetB0~B7 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | 模型使用TensorFlow 2.3导出、转换 | +| EfficientNetB0~B7 | [脚本链接](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.5脚本链接](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) [TF2.3脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | | +| Unet | [脚本链接](https://github.com/milesial/Pytorch-UNet) | [脚本链接](https://github.com/zhixuhao/unet) | 由于算子`ResizeBilinear`在GPU上未实现,所以当运行在GPU设备上时,算子`ResizeBilinear`需要被替换为算子`ResizeNearest` | ## 使用示例 diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index af9b9ed1..0dafb94f 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -53,8 +53,13 @@ EXPECTED_NUMBER = 1 MIN_SCOPE_LENGTH = 2 +ONNX_OPSET_VERSION = 11 + +MODEL_INPUT_NAME = 'input.1' + NO_CONVERTED_OPERATORS = [ - "onnx::Constant" + "onnx::Constant", + "Constant" ] diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index aeb9099a..31b70d26 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -14,7 +14,6 @@ # ============================================================================== """Graph based scripts converter workflow.""" import os -import 
re import argparse import sys from importlib import import_module @@ -65,10 +64,28 @@ def torch_installation_validation(func): def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None): # Check whether pytorch is installed. - if not find_spec("torch"): - error = RuntimeIntegrityError("PyTorch is required when using graph based " - "scripts converter, and PyTorch version must " - "be consisted with model generation runtime.") + if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"): + error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and " + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) " + f"are required when using graph based " + f"scripts converter, and PyTorch version must " + f"be consisted with model generation runtime.") + log.error(error) + log_console.error("\n") + log_console.error(str(error)) + log_console.error("\n") + sys.exit(0) + + onnx = import_module("onnx") + ort = import_module("onnxruntime") + + if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ + or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER): + error = RuntimeIntegrityError( + f"onnx(>={ONNX_MIN_VER}) and " + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " + f"based scripts converter for Pytorch conversion." + ) log.error(error) log_console.error("\n") log_console.error(str(error)) @@ -154,7 +171,8 @@ def _extract_model_name(model_path): str: Name of Converted model. 
""" - model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1] + base_path = os.path.basename(model_path) + model_name = '.'.join(base_path.split('.')[:-1]) return model_name diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py index fc45a6a9..3cc84216 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py @@ -13,19 +13,28 @@ # limitations under the License. # ============================================================================== """Define PyTorch graph.""" +import os import re +import warnings from copy import deepcopy +from importlib import import_module from typing import Dict, NoReturn +import numpy as np + +from mindinsight.conf import settings from mindinsight.mindconverter.common.log import logger as log from .base import Graph from .input_node import InputNode from .pytorch_graph_node import PyTorchGraphNode from .pytorch_graph_parser import PyTorchGraphParser +from .torch_utils import set_opset_version +from ..common.utils import fetch_output_from_onnx_model from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ - MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT + MIN_SCOPE_LENGTH, SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT, ONNX_OPSET_VERSION, MODEL_INPUT_NAME from ..constant import LEFT_BUCKET, RIGHT_BUCKET +from ...common.exceptions import ModelNotSupportError NONE_SCOPE_OP = { "onnx::Add": "Add", @@ -37,7 +46,39 @@ NONE_SCOPE_OP = { "onnx::Reshape": "Reshape", "onnx::Transpose": "Transpose", "onnx::Constant": "Constant", - "onnx::ReduceMean": "ReduceMean" + "onnx::ReduceMean": "ReduceMean", + "onnx::Resize": "Resize", + "onnx::Pad": "Pad" +} + +CONSTANT_NODES_PATTERN = { + "onnx::Resize": [ + 'onnx::Concat', + 
'onnx::Slice', + 'onnx::Cast', + 'onnx::Concat', + 'onnx::Unsqueeze', + 'onnx::Floor', + 'onnx::Mul', + 'onnx::Cast', + 'onnx::Gather', + 'onnx::Shape' + ], + "onnx::Pad": [ + 'onnx::Cast', + 'onnx::Concat', + 'onnx::ConstantOfShape', + 'onnx::Sub', + 'onnx::Mul', + 'onnx::Div', + 'onnx::Gather', + 'onnx::Shape', + 'onnx::Unsqueeze', + 'onnx::Slice', + 'onnx::Reshape', + 'onnx::Transpose' + ], + "onnx::Constant": list() } @@ -129,6 +170,15 @@ class PyTorchGraph(Graph): from .torch_utils import unique_state_dict self._params_dict = unique_state_dict(model) + self._original_shape = list() + self._nodes = list() + self._constant_nodes = list() + self._dynamic_nodes = list() + self._has_eliminated_nodes = False + self._file_graph_onnx = os.path.join( + settings.WORKSPACE, 'log/mindconverter/' + ) + self.build(sample_shape) @staticmethod @@ -195,17 +245,32 @@ class PyTorchGraph(Graph): from .torch_utils import create_autograd_variable from .torch_utils import onnx_tracer + warnings.simplefilter("ignore") + batched_sample = create_autograd_variable(torch.rand(*input_shape)) try: - # Assign execution mode to eval. - self.model.eval() - - with OverloadTorchModuleTemporarily() as _: - # In pytorch higher version, trace function has a known. - graph = onnx_tracer(self.model, batched_sample, - OperatorExportTypes.ONNX) - return graph + try: + # Assign execution mode to eval. + self.model.eval() + + with OverloadTorchModuleTemporarily() as _: + # In pytorch higher version, trace function has a known. + graph = onnx_tracer(self.model, batched_sample, + OperatorExportTypes.ONNX) + return graph + + except RuntimeError: + # Assign execution mode to eval. + self.model.eval() + + with OverloadTorchModuleTemporarily() as _: + # In pytorch higher version, trace function has a known. 
+ set_opset_version(ONNX_OPSET_VERSION) + graph = onnx_tracer(self.model, batched_sample, + OperatorExportTypes.ONNX) + return graph + except RuntimeError as error: log.error(str(error)) log.exception(error) @@ -220,14 +285,21 @@ class PyTorchGraph(Graph): """ self._check_input_shape(input_shape) - + self._original_shape = input_shape feed_forward_ipt_shape = tuple(input_shape) graph = self._trace_torch_graph(feed_forward_ipt_shape) nodes = list(graph.nodes()) + self._nodes = nodes scope_name_dict = dict() + self._constant_nodes, self._dynamic_nodes = self._get_constant_nodes(nodes) + for node in nodes: + output_name = ', '.join(list(self._extract_node_name(output) for output in node.outputs())) + if output_name in self._dynamic_nodes: + continue + node_name = normalize_scope_name(node, scope_name_dict) scope_name_dict[node_name.split(SEPARATOR_BTW_NAME_AND_ID)[-1]] \ = list(node_name.split(SEPARATOR_BTW_NAME_AND_ID)[0].split(SEPARATOR_IN_SCOPE)) @@ -237,16 +309,33 @@ class PyTorchGraph(Graph): weight_scope = '.'.join( re.findall(r'\[([\w\d.]+)]', node.scopeName()) ) - node_weight = {} + + if self._constant_nodes: + node_weight = self._replace_constant_node(node) + else: + node_weight = {} + for scope, weight in self._params_dict.items(): split_scope = scope.split('.') if '.'.join(split_scope[:-1]) == weight_scope: node_weight[split_scope[-1]] = weight + + if not node_weight and node.kind() == 'onnx::Conv': + weight_names = list(self._params_dict.keys()) + node_input_names = [self._extract_input_name(node_input) for node_input in node.inputs()] + for node_input_name in node_input_names: + if int(node_input_name) > len(weight_names): + continue + weight = self._params_dict[weight_names[int(node_input_name) - 1]] + node_weight[weight_names[int(node_input_name) - 1]] = weight + self._shape_dict[node_name] = output_shape self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight) self._nodes_record[node_name] = node_name for node_input in 
list(node.inputs()): + if self._extract_input_name(node_input) in self._constant_nodes: + continue # Connect input node and src node. nd_id = PyTorchGraph.get_node_id(node_input.node()) nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \ @@ -263,6 +352,165 @@ class PyTorchGraph(Graph): super(PyTorchGraph, self).build(input_shape=input_shape) self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) + @staticmethod + def _extract_node_name(node): + """Extract node name for node.""" + result = re.match(r"\d+", str(node)) + if result: + return result.group(0) + return None + + @staticmethod + def _extract_input_name(node_input): + """Extract node input name from node input.""" + node_input_name = str(node_input).split('defined in')[0].strip() + return node_input_name + + def _get_constant_nodes(self, nodes): + """ + Get constant nodes to be eliminated. + + Args: + nodes (Nodes): Nodes in torch._C.Graph. + + Returns: + Union(dict, list), output of constant_input_node_name and dynamic nodes name. 
+ """ + constant_input_nodes = list() + dynamic_nodes = list() + for node in nodes: + if node.kind() == 'onnx::Resize': + self._has_eliminated_nodes = True + constant_input_node, dynamic_node = self._generate_inputs_of(node) + constant_input_nodes += constant_input_node + dynamic_nodes += dynamic_node + + outputs = dict() + if self._has_eliminated_nodes: + torch = import_module('torch') + device_target = 'cuda' if torch.cuda.is_available() else 'cpu' + dump_input = torch.randn(*self._original_shape, device=device_target) + temp_onnx_path = os.path.realpath(os.path.join(self._file_graph_onnx, + '.graph_onnx.onnx')) + + symbolic_helper = import_module('torch.onnx.symbolic_helper') + export_onnx_opset_version = getattr(symbolic_helper, '_export_onnx_opset_version') + try: + torch.onnx.export(self.model.to(device_target), dump_input, + temp_onnx_path, opset_version=export_onnx_opset_version) + + outputs = self._onnx_infer(temp_onnx_path, constant_input_nodes, self._original_shape) + finally: + if os.path.exists(temp_onnx_path): + os.remove(temp_onnx_path) + + return outputs, dynamic_nodes + + def _generate_inputs_of(self, node): + """ + Generate inputs of certain node. + + Args: + node (Node): Node of torch._C.Graph. 
+ + """ + pattern_op_lst = CONSTANT_NODES_PATTERN.get(node.kind(), None) + constant_input_nodes = list() + dynamic_nodes = list() + if not isinstance(pattern_op_lst, list): + return constant_input_nodes, dynamic_nodes + if not pattern_op_lst: + dynamic_nodes += self.get_node_id(node) + return constant_input_nodes, dynamic_nodes + + node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] + + for node_input_name in node_inputs_name: + node_name_path = self._search_node_path(node_input_name, pattern_op_lst) + if node_name_path and self._get_node_from_graph(node_name_path[-1]).kind() == 'onnx::Shape': + constant_input_nodes.append(node_input_name) + dynamic_nodes += node_name_path + + return constant_input_nodes, dynamic_nodes + + def _search_node_path(self, node_name, pattern_op_lst): + """ + Search node path based on pattern_op_list. + + Args: + node_name (str): Node name. + pattern_op_lst (list): Pattern list of certain operator. + + Returns: + list[str]: node names in pattern. + """ + node_type_lst = list() + node_name_lst = list() + node = self._get_node_from_graph(node_name) + + if node_name == MODEL_INPUT_NAME: + return node_name_lst + + if node.kind() not in pattern_op_lst: + return node_name_lst + + node_type_lst.append(node.kind()) + node_name_lst.append(node_name) + + node_inputs_name = [self._extract_input_name(node_input) for node_input in node.inputs()] + for node_input_name in node_inputs_name: + node_name_lst += self._search_node_path(node_input_name, pattern_op_lst) + + return node_name_lst + + def _get_node_from_graph(self, node_name): + """Get torch._C.Node from torch._C.Graph.""" + for idx, node in enumerate(self._nodes): + node_id = ', '.join(self.get_node_id(node)) + if node_id == node_name: + return self._nodes[idx] + return None + + @staticmethod + def _onnx_infer(file_graph_onnx, infer_outputs, infer_inputs_shape): + """ + Infer onnx model to get outputs of inner nodes. 
+ + Args: + file_graph_onnx (str): File path of onnx. + infer_outputs (list): Outputs for infer. + infer_inputs_shape (list): Input shape for infer. + + """ + onnx = import_module('onnx') + tensor_proto = getattr(onnx, 'TensorProto') + onnx_model = onnx.load(file_graph_onnx) + + for onnx_node in onnx_model.graph.node: + if set(onnx_node.output).issubset(set(infer_outputs)): + onnx_node.name = ', '.join([f"{output_name}" for output_name in onnx_node.output]) + + input_onnx = onnx_model.graph.input[0] + node_type = tensor_proto.DataType.Name(input_onnx.type.tensor_type.elem_type) + if node_type != 'FLOAT': + raise ModelNotSupportError(f"Input type should be FLOAT32, but got {node_type}. " + f"Please report issue to us if extra input type is needed.") + + input_onnx_name = input_onnx.name + feed_dict = {input_onnx_name: np.random.rand(*infer_inputs_shape).astype(np.float32)} + outputs = fetch_output_from_onnx_model(onnx_model, feed_dict, infer_outputs) + + return outputs + + def _replace_constant_node(self, node): + """Replace constant node.""" + node_weight = dict() + for node_input in list(node.inputs()): + node_input_name = self._extract_input_name(node_input) + if node_input_name in self._constant_nodes: + node_weight[node_input_name] = self._constant_nodes[node_input_name] + return node_weight + def _collect_ipt_shape_of_each_node(self, input_shape): """ Collect input tensor shape of each node. 
@@ -338,7 +586,7 @@ class PyTorchGraph(Graph): def _unmerge_multi_ipt_opt_script(self): """Unmerge all submodule.""" - if self._check_multi_ipt_opt(): + if self._check_multi_ipt_opt() or self._has_eliminated_nodes: for node_key, node_inst in deepcopy(self._nodes_collection).items(): prsc_nodes = node_inst.precursor_nodes scsr_nodes = node_inst.successor_nodes diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py index 9e5c2494..623d4089 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/torch_utils.py @@ -18,11 +18,14 @@ import importlib from torch.nn import Module from torch.onnx.utils import _trace from torch.onnx.utils import _node_getitem +from torch.onnx.symbolic_helper import _set_opset_version + SCRIPT_METHOD = getattr(importlib.import_module("torch._C"), "ScriptMethod") onnx_tracer = _trace getitem_of_node = _node_getitem +set_opset_version = _set_opset_version def unique_state_dict(model):