framework.py
# Copyright 2020-2021 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Graph based scripts converter workflow."""
  16. import multiprocessing as mp
  17. import os
  18. import re
  19. import sys
  20. from typing import List
  21. from importlib import import_module
  22. from importlib.util import find_spec
  23. from functools import partial
  24. from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
  25. from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
  26. save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
  27. from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
  28. ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER, TORCH_MIN_VER
  29. from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
  30. from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
  31. from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
  32. from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
  33. BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError, \
  34. BadParamError
  35. from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
  36. from google.protobuf.internal import api_implementation
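
# `check_common_dependency_integrity` pre-binds the ONNX packages required by
# every conversion path, so call sites only pass framework-specific extras
# (e.g. check_common_dependency_integrity("tensorflow")).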
check_common_dependency_integrity = partial(check_dependency_integrity,
                                            "onnx", "onnxruntime", "onnxoptimizer")


def onnx_lib_version_satisfied():
    """Check whether the installed onnx libs satisfy the version requirements."""
    onnx = import_module("onnx")
    ort = import_module("onnxruntime")
    optimizer = import_module("onnxoptimizer.version")
    # An outdated onnxruntime only triggers a warning; an outdated onnx or
    # onnxoptimizer fails the whole check.
    if not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER):
        log_console.warning("onnxruntime's version should be greater than %s, however current version is %s.",
                            ONNXRUNTIME_MIN_VER, ort.__version__)
    if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
            or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER):
        return False
    return True


def _print_error(err):
    """Print error to stdout and record it."""
    log.error(err)
    log_console.error("\n")
    log_console.error(str(err))
    log_console.error("\n")


def torch_version_satisfied(output_queue):
    """Check whether the installed Torch version is satisfied."""
    satisfied = False
    # Extract the "major.minor.patch" part, e.g. "1.8.1" from "1.8.1+cu111".
    pattern = r"\d+\.\d+\.\d+"
    torch_version = re.findall(pattern, getattr(import_module('torch'), "__version__"))
    if torch_version:
        satisfied = lib_version_satisfied(torch_version[0], TORCH_MIN_VER)
    output_queue.put(satisfied)


def torch_installation_validation(func):
    """
    Validate args of func.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
           output_folder: str, report_folder: str = None):
        # Check whether pytorch is installed.
        error_info = None
        torch_version_validation = False
        if graph_path.endswith('.onnx'):
            if not onnx_satisfied() or not check_common_dependency_integrity():
                error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
                             f"are required when using graph based scripts converter."
        else:
            if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity():
                error_info = \
                    f"{get_third_part_lib_validation_error_info(['torch', 'onnx', 'onnxruntime', 'onnxoptimizer'])} " \
                    f"are required when using graph based scripts converter, and the PyTorch version must " \
                    f"be consistent with the model generation runtime."

        if not error_info:
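            # Probe the torch version in a child process so that the heavyweight
            # torch import stays out of the main converter process.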
            output_queue = mp.Queue()
            process = mp.Process(target=torch_version_satisfied, args=(output_queue,))
            process.start()
            torch_version_validation = output_queue.get()
            process.join()

        if error_info:
            _print_error(RuntimeIntegrityError(error_info))
            sys.exit(0)

        if (not torch_version_validation and not graph_path.endswith('.onnx')) or not onnx_lib_version_satisfied():
            lib_check_list = ['onnx', 'onnxruntime', 'onnxoptimizer']
            if not graph_path.endswith('.onnx'):
                lib_check_list.insert(0, 'torch')
            error = RuntimeIntegrityError(
                f"{get_third_part_lib_validation_error_info(lib_check_list)} "
                f"are required when using graph based scripts converter."
            )
            _print_error(error)
            sys.exit(0)

        func(graph_path=graph_path,
             input_nodes=input_nodes, output_nodes=output_nodes,
             output_folder=output_folder, report_folder=report_folder)

    return _f


def _check_tf_installation():
    """
    Check whether TensorFlow was installed.

    Returns:
        bool, True if TensorFlow is installed.
    """
    return find_spec("tensorflow") or find_spec("tensorflow-gpu")


def tf_installation_validation(func):
    """
    Validate args of func.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
           output_folder: str, report_folder: str):
        not_integral_error = RuntimeIntegrityError(
            f"TensorFlow, "
            f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
            f"are required when using graph based scripts converter for TensorFlow conversion."
        )
        # Check whether tensorflow is installed.
        if not _check_tf_installation() or not onnx_satisfied():
            _print_error(not_integral_error)
            sys.exit(0)

        if not any([check_common_dependency_integrity("tensorflow"),
                    check_common_dependency_integrity("tensorflow-gpu")]):
            _print_error(not_integral_error)
            sys.exit(0)

        tf2onnx = import_module("tf2onnx")
        if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \
                or not onnx_lib_version_satisfied():
            _print_error(not_integral_error)
            sys.exit(0)

        func(graph_path=graph_path,
             input_nodes=input_nodes, output_nodes=output_nodes,
             output_folder=output_folder, report_folder=report_folder)

    return _f


def _extract_model_name(model_path):
    """
    Extract model name from model path.

    Args:
        model_path (str): Path of the model to convert.

    Returns:
        str, name of the converted model.
    """
    base_path = os.path.basename(model_path)
    model_name = '.'.join(base_path.split('.')[:-1])
    return model_name


@torch_installation_validation
@GraphInitError.uniform_catcher()
@TreeCreationError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_pytorch_to_ms(graph_path: str,
                                        input_nodes: dict, output_nodes: List[str],
                                        output_folder: str, report_folder: str = None):
    """
    PyTorch to MindSpore based on Graph.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
    generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
    model_name = _extract_model_name(graph_path)
    code_fragments = generator_inst.generate()
    save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
    # Release global context.
    GlobalContext.release()


@tf_installation_validation
@GraphInitError.uniform_catcher()
@TfRuntimeError.uniform_catcher()
@TreeCreationError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_tf_to_ms(graph_path: str,
                                   input_nodes: dict, output_nodes: List[str],
                                   output_folder: str, report_folder: str = None):
    """
    TensorFlow to MindSpore based on Graph.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    # Suppress unnecessary TensorFlow logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
    generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
    model_name = _extract_model_name(graph_path)
    code_fragments = generator_inst.generate()
    save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
    # Release global context.
    GlobalContext.release()


@BaseConverterError.uniform_catcher()
def main_graph_base_converter(file_config):
    """
    The entrance of the converter; script files will be converted.

    Args:
        file_config (dict): Configuration of the file to convert.
    """
    if api_implementation.Type() != 'cpp' or os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION') != 'cpp':
        log_console.warning("Protobuf is currently implemented in \"Python\". "
                            "The conversion process may take a long time. Please use the \"C++\" backend version.")
    graph_path = file_config['model_file']
    frame_type = get_framework_type(graph_path)
    if not file_config.get("shape"):
        raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")

    # A single-input PyTorch checkpoint falls back to the default exported input
    # name "input.1"; otherwise input and output nodes must be given explicitly.
    if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \
            file_config.get("shape") and len(file_config.get("shape", ())) == 1:
        file_config['input_nodes'] = ["input.1"]
    else:
        check_params = ['input_nodes', 'output_nodes']
        check_params_exist(check_params, file_config)
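
    # Chained comparison: this is equivalent to
    # len(shape) != len(input_nodes) and len(input_nodes) != len(set(input_nodes)).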
    if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len(
            set(file_config.get("input_nodes", []))):
        raise BadParamError("`--shape` and `--input_nodes` must have the same length, "
                            "and no redundant node in `--input_nodes`.")

    input_nodes = dict()
    for shape, node in zip(file_config['shape'], file_config['input_nodes']):
        input_nodes[node] = shape

    if frame_type == FrameworkType.PYTORCH.value:
        if graph_path.endswith('.onnx'):
            graph_based_converter_pytorch_to_ms(graph_path=graph_path,
                                                input_nodes=input_nodes,
                                                output_nodes=file_config['output_nodes'],
                                                output_folder=file_config['outfile_dir'],
                                                report_folder=file_config['report_dir'])
        else:
            graph_based_converter_pytorch_to_ms(graph_path=graph_path,
                                                input_nodes=input_nodes,
                                                output_nodes=[],
                                                output_folder=file_config['outfile_dir'],
                                                report_folder=file_config['report_dir'])
    elif frame_type == FrameworkType.TENSORFLOW.value:
        graph_based_converter_tf_to_ms(graph_path=graph_path,
                                       input_nodes=input_nodes,
                                       output_nodes=file_config['output_nodes'],
                                       output_folder=file_config['outfile_dir'],
                                       report_folder=file_config['report_dir'])
    else:
        error_msg = "Get UNSUPPORTED model."
        error = UnknownModelError(error_msg)
        raise error


def check_params_exist(params: list, config):
    """Check params exist."""
    miss_param_list = ''
    for param in params:
        if not config.get(param) or not config[param]:
            miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

    if miss_param_list:
        raise ParamMissingError(f"Param(s) missing, {miss_param_list} is(are) required when using graph mode.")
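

# ------------------------------------------------------------------------------
# Usage sketch (not part of the original module): `main_graph_base_converter` is
# the converter entry point; a minimal `file_config`, assuming hypothetical file
# paths and node names, might look like this:
#
#     main_graph_base_converter({
#         "model_file": "/models/resnet50.onnx",  # hypothetical model path
#         "shape": [[1, 3, 224, 224]],            # one shape per input node
#         "input_nodes": ["input.1"],             # hypothetical input node name
#         "output_nodes": ["output.1"],           # hypothetical output node name
#         "outfile_dir": "/tmp/converted",        # generated MindSpore scripts
#         "report_dir": "/tmp/report",            # conversion report
#     })
# ------------------------------------------------------------------------------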