|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Graph based scripts converter workflow."""
- import multiprocessing as mp
- import os
- import re
- import sys
- from typing import List
- from importlib import import_module
- from importlib.util import find_spec
- from functools import partial
-
- from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
- from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
- save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
- from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
- ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER, TORCH_MIN_VER
- from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
- from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
- from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
- from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
- BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError, \
- BadParamError
- from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
-
# Dependency check pre-bound to the ONNX tool-chain modules; callers may append
# extra libs positionally (e.g. check_common_dependency_integrity("tensorflow")).
check_common_dependency_integrity = partial(check_dependency_integrity,
                                            "onnx", "onnxruntime", "onnxoptimizer")
-
-
def onnx_lib_version_satisfied():
    """Check whether the versions of the installed onnx libs are satisfied."""
    onnx_ver = getattr(import_module("onnx"), "__version__")
    ort_ver = getattr(import_module("onnxruntime"), "__version__")
    optimizer_ver = getattr(import_module("onnxoptimizer.version"), "version")

    # An outdated onnxruntime only triggers a warning, not a hard failure.
    if not lib_version_satisfied(ort_ver, ONNXRUNTIME_MIN_VER):
        log_console.warning("onnxruntime's version should be greater than %s, however current version is %s.",
                            ONNXRUNTIME_MIN_VER, ort_ver)

    satisfied = lib_version_satisfied(onnx_ver, ONNX_MIN_VER) \
        and lib_version_satisfied(optimizer_ver, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER)
    return satisfied
-
-
def _print_error(err):
    """Record *err* in the log file and echo it to the console, framed by blank lines."""
    log.error(err)
    for console_msg in ("\n", str(err), "\n"):
        log_console.error(console_msg)
-
-
def torch_version_satisfied(output_queue):
    """Put whether the installed Torch version is satisfied into *output_queue*.

    Runs in a subprocess; the result (bool) is communicated through the queue.
    """
    version_match = re.search(r"\d+\.\d+\.\d+", getattr(import_module('torch'), "__version__"))
    # No parsable x.y.z version string means the check fails.
    result = bool(version_match) and lib_version_satisfied(version_match.group(0), TORCH_MIN_VER)
    output_queue.put(result)
-
-
def torch_installation_validation(func):
    """
    Validate args of func.

    Decorator that verifies the PyTorch/ONNX runtime before running the
    wrapped converter entry point.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
           output_folder: str, report_folder: str = None):
        # Check whether pytorch / the onnx tool-chain is installed.
        error_info = None
        torch_version_validation = False
        if graph_path.endswith('.onnx'):
            if not onnx_satisfied() or not check_common_dependency_integrity():
                error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
                             f"are required when using graph based scripts converter."
        else:
            if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity():
                error_info = \
                    f"{get_third_part_lib_validation_error_info(['torch', 'onnx', 'onnxruntime', 'onnxoptimizer'])} " \
                    f"are required when using graph based scripts converter, and PyTorch version must " \
                    f"be consisted with model generation runtime."

        # BUGFIX: report integrity errors BEFORE spawning the version-check
        # subprocess. If torch is not importable, the child dies on
        # import_module('torch') without ever putting a result on the queue,
        # and output_queue.get() would block forever.
        if error_info:
            _print_error(RuntimeIntegrityError(error_info))
            sys.exit(0)

        if not graph_path.endswith('.onnx'):
            # Torch is imported in a subprocess so the heavyweight import does
            # not stay loaded in the parent interpreter.
            output_queue = mp.Queue()
            process = mp.Process(target=torch_version_satisfied, args=(output_queue,))
            process.start()
            process.join()
            # If the child crashed before publishing a result, treat the torch
            # version as unsatisfied instead of blocking on an empty queue.
            torch_version_validation = output_queue.get() if not output_queue.empty() else False

        if (not torch_version_validation and not graph_path.endswith('.onnx')) or not onnx_lib_version_satisfied():
            lib_check_list = ['onnx', 'onnxruntime', 'onnxoptimizer']
            if not graph_path.endswith('.onnx'):
                lib_check_list.insert(0, 'torch')
            error = RuntimeIntegrityError(
                f"{get_third_part_lib_validation_error_info(lib_check_list)} "
                f"are required when using graph based scripts converter."
            )
            _print_error(error)
            sys.exit(0)

        func(graph_path=graph_path,
             input_nodes=input_nodes, output_nodes=output_nodes,
             output_folder=output_folder, report_folder=report_folder)

    return _f
-
-
- def _check_tf_installation():
- """
- Check whether TensorFlow was installed.
-
- Returns:
- bool, true or false.
- """
- return find_spec("tensorflow") or find_spec("tensorflow-gpu")
-
-
def tf_installation_validation(func):
    """
    Validate args of func.

    Decorator that verifies the TensorFlow/ONNX runtime before running the
    wrapped converter entry point.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None,
           input_nodes: str = None, output_nodes: str = None):
        integrity_error = RuntimeIntegrityError(
            f"TensorFlow, "
            f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
            f"are required when using graph based scripts converter for TensorFlow conversion."
        )

        # Guard 1: tensorflow and the onnx tool-chain must be installed.
        if not (_check_tf_installation() and onnx_satisfied()):
            _print_error(integrity_error)
            sys.exit(0)

        # Guard 2: the dependency set must be importable together with either
        # tensorflow flavour.
        if not (check_common_dependency_integrity("tensorflow")
                or check_common_dependency_integrity("tensorflow-gpu")):
            _print_error(integrity_error)
            sys.exit(0)

        # Guard 3: versions of tf2onnx and the onnx libs must be satisfied.
        tf2onnx_version = getattr(import_module("tf2onnx"), "__version__")
        if not (lib_version_satisfied(tf2onnx_version, TF2ONNX_MIN_VER)
                and onnx_lib_version_satisfied()):
            _print_error(integrity_error)
            sys.exit(0)

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder,
             input_nodes=input_nodes, output_nodes=output_nodes)

    return _f
-
-
- def _extract_model_name(model_path):
- """
- Extract model name from model path.
-
- Args:
- model_path (str): Path of Converted model.
-
- Returns:
- str, name of Converted model.
- """
-
- base_path = os.path.basename(model_path)
- model_name = '.'.join(base_path.split('.')[:-1])
- return model_name
-
-
@torch_installation_validation
@GraphInitError.uniform_catcher()
@TreeCreationError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_pytorch_to_ms(graph_path: str,
                                        input_nodes: dict, output_nodes: List[str],
                                        output_folder: str, report_folder: str = None):
    """
    PyTorch to MindSpore based on Graph.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    model_name = _extract_model_name(graph_path)
    graph = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
    generator = batch_add_nodes(graph, ONNXToMindSporeMapper)
    save_code_file_and_report(model_name, generator.generate(), output_folder, report_folder)
    # Free everything cached in the global context for this conversion.
    GlobalContext.release()
-
@tf_installation_validation
@GraphInitError.uniform_catcher()
@TfRuntimeError.uniform_catcher()
@TreeCreationError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_tf_to_ms(graph_path: str,
                                   input_nodes: dict, output_nodes: List[str],
                                   output_folder: str, report_folder: str = None):
    """
    Tensorflow to MindSpore based on Graph.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    # Close unnecessary log.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    model_name = _extract_model_name(graph_path)
    graph = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
    generator = batch_add_nodes(graph, ONNXToMindSporeMapper)
    save_code_file_and_report(model_name, generator.generate(), output_folder, report_folder)
    # Free everything cached in the global context for this conversion.
    GlobalContext.release()
-
-
@BaseConverterError.uniform_catcher()
def main_graph_base_converter(file_config):
    """
    The entrance for converter, script files will be converted.

    Args:
        file_config (dict): The config of file which to convert.

    Raises:
        ParamMissingError: If `--shape` is absent.
        BadParamError: If `--shape` and `--input_nodes` lengths differ or
            `--input_nodes` contains duplicates.
        UnknownModelError: If the model framework cannot be recognized.
    """
    graph_path = file_config['model_file']
    frame_type = get_framework_type(graph_path)
    if not file_config.get("shape"):
        raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")

    # A single-input `.pth` model without an explicit input node defaults to
    # PyTorch's conventional first-input name "input.1".
    if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \
            file_config.get("shape") and len(file_config.get("shape", ())) == 1:
        file_config['input_nodes'] = ["input.1"]

    input_node_names = file_config.get("input_nodes", [])
    # BUGFIX: the original chained comparison
    # `len(shape) != len(input_nodes) != len(set(input_nodes))` only raised when
    # BOTH conditions held, so a length mismatch alone (silently truncated by
    # zip below) or duplicate nodes alone slipped through. Reject either.
    if len(file_config['shape']) != len(input_node_names) \
            or len(input_node_names) != len(set(input_node_names)):
        raise BadParamError("`--shape` and `--input_nodes` must have the same length, "
                            "and no redundant node in `--input_nodes`.")

    # Map each input node name to its shape.
    input_nodes = dict(zip(input_node_names, file_config['shape']))

    if frame_type == FrameworkType.PYTORCH.value:
        if graph_path.endswith('.onnx'):
            # ONNX input requires explicit input/output node names.
            check_params_exist(['input_nodes', 'output_nodes'], file_config)
            output_nodes = file_config['output_nodes']
        else:
            output_nodes = []
        graph_based_converter_pytorch_to_ms(graph_path=graph_path,
                                            input_nodes=input_nodes,
                                            output_nodes=output_nodes,
                                            output_folder=file_config['outfile_dir'],
                                            report_folder=file_config['report_dir'])
    elif frame_type == FrameworkType.TENSORFLOW.value:
        check_params_exist(['input_nodes', 'output_nodes'], file_config)
        graph_based_converter_tf_to_ms(graph_path=graph_path,
                                       input_nodes=input_nodes,
                                       output_nodes=file_config['output_nodes'],
                                       output_folder=file_config['outfile_dir'],
                                       report_folder=file_config['report_dir'])
    else:
        raise UnknownModelError("Get UNSUPPORTED model.")
-
-
def check_params_exist(params: list, config):
    """Check that every param in *params* is present and non-empty in *config*.

    Args:
        params (list): Names of required config keys.
        config (dict): Parsed CLI/config mapping to validate.

    Raises:
        ParamMissingError: If any required param is missing or empty.
    """
    # NOTE: the original also tested `not config[param]`, but short-circuit
    # evaluation made that branch unreachable (config.get(param) truthy implies
    # the key exists and is truthy), so a single .get() check suffices.
    missing = [param for param in params if not config.get(param)]
    if missing:
        raise ParamMissingError(f"Param(s) missing, {', '.join(missing)} is(are) required when using graph mode.")
|