
framework.py 12 kB

# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph based scripts converter workflow."""
import os
import re
import argparse
import sys
from importlib import import_module
from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
    save_code_file_and_report
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
    BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \
    BaseConverterFail, UnknownModel, GeneratorFail, TfRuntimeError
from mindinsight.utils.exceptions import ParamMissError

# Restrict files created by this process to owner-only access (umask 0o077).
permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

parser = argparse.ArgumentParser(
    prog="MindConverter",
    description="Graph based MindConverter CLI entry point (version: {})".format(
        mindinsight.__version__)
)

parser.add_argument("--graph", type=str, required=True,
                    help="Third party framework's graph path.")
parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
                    help="Input shape of the model.")
parser.add_argument("--ckpt", type=str, required=False,
                    help="Third party framework's checkpoint path.")
parser.add_argument("--output", type=str, required=True,
                    help="Generated scripts output folder path.")
parser.add_argument("--report", type=str, required=False,
                    help="Generated reports output folder path.")
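
# Illustrative invocation of the flags defined above (the command name is a
# placeholder; the code that actually calls parser.parse_args() lives elsewhere
# in MindConverter):
#   mindconverter --graph /path/to/model.pb --sample_shape 1 224 224 3 \
#       --output ./out_dir --report ./report_dir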


def torch_installation_validation(func):
    """
    Validate args of func.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None):
        # Check whether PyTorch is installed.
        if not find_spec("torch"):
            error = ModuleNotFoundError("PyTorch is required when using graph based "
                                        "scripts converter, and the PyTorch version must "
                                        "be consistent with the model generation runtime.")
            log.error(str(error))
            detail_info = f"Error detail: {str(error)}"
            log_console.error(str(error))
            log_console.error(detail_info)
            sys.exit(0)

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder)

    return _f


def tf_installation_validation(func):
    """
    Validate args of func.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None,
           input_nodes: str = None, output_nodes: str = None):
        # Check whether TensorFlow and the ONNX tool chain are installed.
        if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnx") \
                or not find_spec("onnxruntime"):
            error = ModuleNotFoundError(
                f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                f"based scripts converter for TensorFlow conversion."
            )
            log.error(str(error))
            detail_info = f"Error detail: {str(error)}"
            log_console.error(str(error))
            log_console.error(detail_info)
            sys.exit(0)

        # Check whether the installed versions satisfy the minimum requirements.
        onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
        ort = import_module("onnxruntime")
        if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
                or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
                or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER):
            error = ModuleNotFoundError(
                f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
                f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
                f"based scripts converter for TensorFlow conversion."
            )
            log.error(str(error))
            detail_info = f"Error detail: {str(error)}"
            log_console.error(str(error))
            log_console.error(detail_info)
            sys.exit(0)

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder,
             input_nodes=input_nodes, output_nodes=output_nodes)

    return _f


def _extract_model_name(model_path):
    """
    Extract model name from model path.

    Args:
        model_path (str): Path of the converted model.

    Returns:
        str, name of the converted model.
    """
    model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1]
    return model_name
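
# Example (hypothetical path): _extract_model_name("/home/user/resnet50.pb")
# returns "resnet50".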


@torch_installation_validation
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
                                        output_folder: str, report_folder: str = None):
    """
    PyTorch to MindSpore conversion based on graph.

    Args:
        graph_path (str): Graph file path.
        sample_shape (tuple): Input shape of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    third_party_graph_module = import_module(
        'mindinsight.mindconverter.graph_based_converter.third_party_graph')
    hierarchical_tree_module = import_module(
        'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
    cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
    cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')

    graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape)
    hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)
    model_name = _extract_model_name(graph_path)

    hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                                        model_name=model_name,
                                        report_folder=report_folder)


@tf_installation_validation
@GraphInitFail.uniform_catcher("Error occurred when init graph object.")
@TfRuntimeError.uniform_catcher("Error occurred when init graph, TensorFlow runtime error.")
@TreeCreateFail.uniform_catcher("Error occurred when create hierarchical tree.")
@SourceFilesSaveFail.uniform_catcher("Error occurred when save source files.")
@GeneratorFail.uniform_catcher("Error occurred when generate code.")
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
                                   input_nodes: str, output_nodes: str,
                                   output_folder: str, report_folder: str = None):
    """
    TensorFlow to MindSpore conversion based on graph.

    Args:
        graph_path (str): Graph file path.
        sample_shape (tuple): Input shape of the model.
        input_nodes (str): Input node(s) of the model.
        output_nodes (str): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    third_party_graph_module = import_module(
        'mindinsight.mindconverter.graph_based_converter.third_party_graph')
    cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
    batch_add_nodes = getattr(import_module('mindinsight.mindconverter.graph_based_converter.generator'),
                              "batch_add_nodes")

    # Suppress unnecessary TensorFlow logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
                                       input_nodes=input_nodes, output_nodes=output_nodes)
    generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
    model_name = _extract_model_name(graph_path)
    code_fragments = generator_inst.generate()
    save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)


@BaseConverterFail.uniform_catcher("Failed to start base converter.")
def main_graph_base_converter(file_config):
    """
    Entrance for the converter; script files will be converted.

    Args:
        file_config (dict): The config of the file to be converted.
    """
    graph_path = file_config['model_file']
    frame_type = get_framework_type(graph_path)
    if frame_type == FrameworkType.PYTORCH.value:
        graph_based_converter_pytorch_to_ms(graph_path=graph_path,
                                            sample_shape=file_config['shape'],
                                            output_folder=file_config['outfile_dir'],
                                            report_folder=file_config['report_dir'])
    elif frame_type == FrameworkType.TENSORFLOW.value:
        check_params = ['input_nodes', 'output_nodes']
        check_params_exist(check_params, file_config)
        graph_based_converter_tf_to_ms(graph_path=graph_path,
                                       sample_shape=file_config['shape'],
                                       input_nodes=file_config['input_nodes'],
                                       output_nodes=file_config['output_nodes'],
                                       output_folder=file_config['outfile_dir'],
                                       report_folder=file_config['report_dir'])
    else:
        error_msg = "Get UNSUPPORTED model."
        error = UnknownModel(error_msg)
        log.error(str(error))
        raise error
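
# Illustrative call for a TensorFlow model (all values are hypothetical; the
# required keys are the ones read by main_graph_base_converter above):
# main_graph_base_converter({
#     'model_file': '/path/to/model.pb',
#     'shape': [1, 224, 224, 3],
#     'input_nodes': 'input:0',
#     'output_nodes': 'output:0',
#     'outfile_dir': './output',
#     'report_dir': './report',
# })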


def get_framework_type(model_path):
    """Get framework type."""
    try:
        with open(model_path, 'rb') as f:
            if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
                framework_type = FrameworkType.PYTORCH.value
            else:
                framework_type = FrameworkType.TENSORFLOW.value
    except IOError:
        error_msg = "Get UNSUPPORTED model."
        error = UnknownModel(error_msg)
        log.error(str(error))
        raise error

    return framework_type
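
# Illustrative behavior (path is hypothetical): get_framework_type("/path/to/model.pth")
# returns FrameworkType.PYTORCH.value when the file begins with the PyTorch binary
# header; any other readable file is treated as a TensorFlow model.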


def check_params_exist(params: list, config):
    """Check params exist."""
    miss_param_list = ''
    for param in params:
        if not config.get(param) or not config[param]:
            miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

    if miss_param_list:
        error = ParamMissError(miss_param_list)
        log.error(str(error))
        raise error