You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

framework.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Graph based scripts converter workflow."""
  16. import os
  17. import sys
  18. from importlib import import_module
  19. from importlib.util import find_spec
  20. from functools import partial
  21. from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
  22. from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
  23. save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
  24. from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
  25. ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER
  26. from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
  27. from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
  28. from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
  29. from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
  30. BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError
  31. from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
# Partially-applied integrity check: every call verifies the shared onnx
# tool-chain ("onnx", "onnxruntime", "onnxoptimizer") plus any extra package
# names the caller supplies (e.g. "torch", "tensorflow").
check_common_dependency_integrity = partial(check_dependency_integrity,
                                            "onnx", "onnxruntime", "onnxoptimizer")
  34. def onnx_lib_version_satisfied():
  35. """Check onnx libs version whether is satisfied."""
  36. onnx = import_module("onnx")
  37. ort = import_module("onnxruntime")
  38. optimizer = import_module("onnxoptimizer.version")
  39. if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
  40. or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
  41. or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER):
  42. return False
  43. return True
  44. def _print_error(err):
  45. """Print error to stdout and record it."""
  46. log.error(err)
  47. log_console.error("\n")
  48. log_console.error(str(err))
  49. log_console.error("\n")
  50. def torch_installation_validation(func):
  51. """
  52. Validate args of func.
  53. Args:
  54. func (type): Function.
  55. Returns:
  56. type, inner function.
  57. """
  58. def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str,
  59. output_folder: str, report_folder: str = None):
  60. # Check whether pytorch is installed.
  61. error_info = None
  62. if graph_path.endswith('.onnx'):
  63. if not onnx_satisfied() or not check_common_dependency_integrity():
  64. error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
  65. f"are required when using graph based scripts converter."
  66. else:
  67. if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity("torch"):
  68. error_info = f"PyTorch, " \
  69. f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
  70. f"are required when using graph based scripts converter, and PyTorch version must " \
  71. f"be consisted with model generation runtime."
  72. if error_info:
  73. _print_error(RuntimeIntegrityError(error_info))
  74. sys.exit(0)
  75. if not onnx_lib_version_satisfied():
  76. error = RuntimeIntegrityError(
  77. f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} "
  78. f"are required when using graph based scripts converter."
  79. )
  80. _print_error(error)
  81. sys.exit(0)
  82. func(graph_path=graph_path, sample_shape=sample_shape,
  83. input_nodes=input_nodes, output_nodes=output_nodes,
  84. output_folder=output_folder, report_folder=report_folder)
  85. return _f
  86. def _check_tf_installation():
  87. """
  88. Check whether TensorFlow was installed.
  89. Returns:
  90. bool, true or false.
  91. """
  92. return find_spec("tensorflow") or find_spec("tensorflow-gpu")
  93. def tf_installation_validation(func):
  94. """
  95. Validate args of func.
  96. Args:
  97. func(type): Function.
  98. Returns:
  99. type, inner function.
  100. """
  101. def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None,
  102. input_nodes: str = None, output_nodes: str = None):
  103. not_integral_error = RuntimeIntegrityError(
  104. f"TensorFlow, "
  105. f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
  106. f"are required when using graph based scripts converter for TensorFlow conversion."
  107. )
  108. # Check whether tensorflow is installed.
  109. if not _check_tf_installation() or not onnx_satisfied():
  110. _print_error(not_integral_error)
  111. sys.exit(0)
  112. if not any([check_common_dependency_integrity("tensorflow"),
  113. check_common_dependency_integrity("tensorflow-gpu")]):
  114. _print_error(not_integral_error)
  115. sys.exit(0)
  116. tf2onnx = import_module("tf2onnx")
  117. if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \
  118. or not onnx_lib_version_satisfied():
  119. _print_error(not_integral_error)
  120. sys.exit(0)
  121. func(graph_path=graph_path, sample_shape=sample_shape,
  122. output_folder=output_folder, report_folder=report_folder,
  123. input_nodes=input_nodes, output_nodes=output_nodes)
  124. return _f
  125. def _extract_model_name(model_path):
  126. """
  127. Extract model name from model path.
  128. Args:
  129. model_path(str): Path of Converted model.
  130. Returns:
  131. str: Name of Converted model.
  132. """
  133. base_path = os.path.basename(model_path)
  134. model_name = '.'.join(base_path.split('.')[:-1])
  135. return model_name
  136. @torch_installation_validation
  137. @GraphInitError.uniform_catcher()
  138. @TreeCreationError.uniform_catcher()
  139. @SourceFilesSaveError.uniform_catcher()
  140. @GeneratorError.uniform_catcher()
  141. def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
  142. input_nodes: str, output_nodes: str,
  143. output_folder: str, report_folder: str = None):
  144. """
  145. PyTorch to MindSpore based on Graph.
  146. Args:
  147. graph_path (str): Graph file path.
  148. sample_shape (tuple): Input shape of the model.
  149. input_nodes (str): Input node(s) of the model.
  150. output_nodes (str): Output node(s) of the model.
  151. output_folder (str): Output folder.
  152. report_folder (str): Report output folder path.
  153. """
  154. graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
  155. input_nodes=input_nodes, output_nodes=output_nodes)
  156. generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
  157. model_name = _extract_model_name(graph_path)
  158. code_fragments = generator_inst.generate()
  159. save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
  160. # Release global context.
  161. GlobalContext.release()
  162. @tf_installation_validation
  163. @GraphInitError.uniform_catcher()
  164. @TfRuntimeError.uniform_catcher()
  165. @TreeCreationError.uniform_catcher()
  166. @SourceFilesSaveError.uniform_catcher()
  167. @GeneratorError.uniform_catcher()
  168. def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
  169. input_nodes: str, output_nodes: str,
  170. output_folder: str, report_folder: str = None):
  171. """
  172. Tensorflow to MindSpore based on Graph.
  173. Args:
  174. graph_path(str): Graph file path.
  175. sample_shape(tuple): Input shape of the model.
  176. input_nodes(str): Input node(s) of the model.
  177. output_nodes(str): Output node(s) of the model.
  178. output_folder(str): Output folder.
  179. report_folder(str): Report output folder path.
  180. """
  181. # Close unnecessary log.
  182. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  183. graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
  184. input_nodes=input_nodes, output_nodes=output_nodes)
  185. generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
  186. model_name = _extract_model_name(graph_path)
  187. code_fragments = generator_inst.generate()
  188. save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
  189. # Release global context.
  190. GlobalContext.release()
  191. @BaseConverterError.uniform_catcher()
  192. def main_graph_base_converter(file_config):
  193. """
  194. The entrance for converter, script files will be converted.
  195. Args:
  196. file_config (dict): The config of file which to convert.
  197. """
  198. graph_path = file_config['model_file']
  199. frame_type = get_framework_type(graph_path)
  200. if not file_config.get("shape"):
  201. raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")
  202. if frame_type == FrameworkType.PYTORCH.value:
  203. graph_based_converter_pytorch_to_ms(graph_path=graph_path,
  204. sample_shape=file_config['shape'],
  205. input_nodes=file_config['input_nodes'] if file_config['input_nodes']
  206. else 'input.1',
  207. output_nodes=file_config['output_nodes'] if file_config['output_nodes']
  208. else '',
  209. output_folder=file_config['outfile_dir'],
  210. report_folder=file_config['report_dir'])
  211. elif frame_type == FrameworkType.TENSORFLOW.value:
  212. check_params = ['input_nodes', 'output_nodes']
  213. check_params_exist(check_params, file_config)
  214. graph_based_converter_tf_to_ms(graph_path=graph_path,
  215. sample_shape=file_config['shape'],
  216. input_nodes=file_config['input_nodes'],
  217. output_nodes=file_config['output_nodes'],
  218. output_folder=file_config['outfile_dir'],
  219. report_folder=file_config['report_dir'])
  220. else:
  221. error_msg = "Get UNSUPPORTED model."
  222. error = UnknownModelError(error_msg)
  223. raise error
  224. def check_params_exist(params: list, config):
  225. """Check params exist."""
  226. miss_param_list = ''
  227. for param in params:
  228. if not config.get(param) or not config[param]:
  229. miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param
  230. if miss_param_list:
  231. raise ParamMissingError(f"Param(s) missing, {miss_param_list} is(are) required when using graph mode.")