You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

framework.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Graph based scripts converter workflow."""
  16. import os
  17. import sys
  18. from typing import List
  19. from importlib import import_module
  20. from importlib.util import find_spec
  21. from functools import partial
  22. from google.protobuf.internal import api_implementation
  23. from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
  24. from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
  25. save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
  26. from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
  27. ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER
  28. from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
  29. from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
  30. from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
  31. from mindinsight.mindconverter.common.exceptions import GraphInitError, SourceFilesSaveError, \
  32. BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError, \
  33. BadParamError
  34. from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
# Pre-bound checker for the core ONNX tool chain; callers may still append
# extra library names (e.g. "tensorflow") as additional positional arguments.
check_common_dependency_integrity = partial(check_dependency_integrity,
                                            "onnx", "onnxruntime", "onnxoptimizer")
  37. def onnx_lib_version_satisfied():
  38. """Check onnx libs version whether is satisfied."""
  39. onnx = import_module("onnx")
  40. ort = import_module("onnxruntime")
  41. optimizer = import_module("onnxoptimizer.version")
  42. if not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER):
  43. log_console.warning(f"onnxruntime's version should be greater than {ONNXRUNTIME_MIN_VER}, "
  44. f"however current version is {ort.__version__}.")
  45. if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
  46. or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER):
  47. return False
  48. return True
  49. def _print_error(err):
  50. """Print error to stdout and record it."""
  51. log.error(err)
  52. log_console.error(str(err))
  53. def onnx_installation_validation(func):
  54. """
  55. Validate args of func.
  56. Args:
  57. func (type): Function.
  58. Returns:
  59. type, inner function.
  60. """
  61. def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
  62. output_folder: str, report_folder: str = None):
  63. # Check whether onnx is installed.
  64. error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
  65. f"are required when using graph based scripts converter or ONNX conversion."
  66. if not onnx_satisfied() or not check_common_dependency_integrity():
  67. _print_error(RuntimeIntegrityError(error_info))
  68. sys.exit(0)
  69. if not onnx_lib_version_satisfied():
  70. _print_error(RuntimeIntegrityError(error_info))
  71. sys.exit(0)
  72. func(graph_path=graph_path,
  73. input_nodes=input_nodes, output_nodes=output_nodes,
  74. output_folder=output_folder, report_folder=report_folder)
  75. return _f
  76. def _check_tf_installation():
  77. """
  78. Check whether TensorFlow was installed.
  79. Returns:
  80. bool, true or false.
  81. """
  82. return find_spec("tensorflow") or find_spec("tensorflow-gpu")
  83. def tf_installation_validation(func):
  84. """
  85. Validate args of func.
  86. Args:
  87. func (type): Function.
  88. Returns:
  89. type, inner function.
  90. """
  91. def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
  92. output_folder: str, report_folder: str):
  93. not_integral_error = RuntimeIntegrityError(
  94. f"TensorFlow, "
  95. f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
  96. f"are required when using graph based scripts converter for TensorFlow conversion."
  97. )
  98. # Check whether tensorflow is installed.
  99. if not _check_tf_installation() or not onnx_satisfied():
  100. _print_error(not_integral_error)
  101. sys.exit(0)
  102. if not any([check_common_dependency_integrity("tensorflow"),
  103. check_common_dependency_integrity("tensorflow-gpu")]):
  104. _print_error(not_integral_error)
  105. sys.exit(0)
  106. tf2onnx = import_module("tf2onnx")
  107. if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \
  108. or not onnx_lib_version_satisfied():
  109. _print_error(not_integral_error)
  110. sys.exit(0)
  111. func(graph_path=graph_path,
  112. input_nodes=input_nodes, output_nodes=output_nodes,
  113. output_folder=output_folder, report_folder=report_folder)
  114. return _f
  115. def _extract_model_name(model_path):
  116. """
  117. Extract model name from model path.
  118. Args:
  119. model_path (str): Path of Converted model.
  120. Returns:
  121. str, name of Converted model.
  122. """
  123. base_path = os.path.basename(model_path)
  124. model_name = '.'.join(base_path.split('.')[:-1])
  125. return model_name
  126. @onnx_installation_validation
  127. @GraphInitError.uniform_catcher()
  128. @SourceFilesSaveError.uniform_catcher()
  129. @GeneratorError.uniform_catcher()
  130. def graph_based_converter_onnx_to_ms(graph_path: str,
  131. input_nodes: dict, output_nodes: List[str],
  132. output_folder: str, report_folder: str = None):
  133. """
  134. ONNX to MindSpore based on Graph.
  135. Args:
  136. graph_path (str): Graph file path.
  137. input_nodes (dict): Input node(s) of the model.
  138. output_nodes (list[str]): Output node(s) of the model.
  139. output_folder (str): Output folder.
  140. report_folder (str): Report output folder path.
  141. """
  142. graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
  143. generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
  144. model_name = _extract_model_name(graph_path)
  145. code_fragments = generator_inst.generate()
  146. save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
  147. # Release global context.
  148. GlobalContext.release()
  149. @tf_installation_validation
  150. @GraphInitError.uniform_catcher()
  151. @TfRuntimeError.uniform_catcher()
  152. @SourceFilesSaveError.uniform_catcher()
  153. @GeneratorError.uniform_catcher()
  154. def graph_based_converter_tf_to_ms(graph_path: str,
  155. input_nodes: dict, output_nodes: List[str],
  156. output_folder: str, report_folder: str = None):
  157. """
  158. Tensorflow to MindSpore based on Graph.
  159. Args:
  160. graph_path (str): Graph file path.
  161. input_nodes (dict): Input node(s) of the model.
  162. output_nodes (list[str]): Output node(s) of the model.
  163. output_folder (str): Output folder.
  164. report_folder (str): Report output folder path.
  165. """
  166. # Close unnecessary log.
  167. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  168. graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
  169. generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
  170. model_name = _extract_model_name(graph_path)
  171. code_fragments = generator_inst.generate()
  172. save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
  173. # Release global context.
  174. GlobalContext.release()
  175. @BaseConverterError.uniform_catcher()
  176. def main_graph_base_converter(file_config):
  177. """
  178. The entrance for converter, script files will be converted.
  179. Args:
  180. file_config (dict): The config of file which to convert.
  181. """
  182. if api_implementation.Type() != 'cpp' or os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION') != 'cpp':
  183. log_console.warning("Protobuf is currently implemented in \"Python\". "
  184. "The conversion process may take a long time. "
  185. "Please use `export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp` to enable cpp backend.")
  186. graph_path = file_config['model_file']
  187. frame_type = get_framework_type(graph_path)
  188. if not file_config.get("shape"):
  189. raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")
  190. check_params = ['input_nodes', 'output_nodes']
  191. check_params_exist(check_params, file_config)
  192. if len(file_config['shape']) != len(file_config.get("input_nodes", [])):
  193. raise BadParamError("`--shape` and `--input_nodes` must have the same length, "
  194. "and no redundant node in `--input_nodes`.")
  195. input_nodes = dict()
  196. for shape, node in zip(file_config['shape'], file_config['input_nodes']):
  197. input_nodes[node] = shape
  198. if frame_type == FrameworkType.ONNX.value:
  199. graph_based_converter_onnx_to_ms(graph_path=graph_path,
  200. input_nodes=input_nodes,
  201. output_nodes=file_config['output_nodes'],
  202. output_folder=file_config['outfile_dir'],
  203. report_folder=file_config['report_dir'])
  204. elif frame_type == FrameworkType.TENSORFLOW.value:
  205. graph_based_converter_tf_to_ms(graph_path=graph_path,
  206. input_nodes=input_nodes,
  207. output_nodes=file_config['output_nodes'],
  208. output_folder=file_config['outfile_dir'],
  209. report_folder=file_config['report_dir'])
  210. else:
  211. error_msg = "Get UNSUPPORTED model."
  212. error = UnknownModelError(error_msg)
  213. raise error
  214. def check_params_exist(params: list, config):
  215. """Check params exist."""
  216. miss_param_list = ''
  217. for param in params:
  218. if not config.get(param) or not config[param]:
  219. miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param
  220. if miss_param_list:
  221. raise ParamMissingError(f"Param(s) missing, {miss_param_list} is(are) required when using graph mode.")