
framework.py

# Copyright 2020 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph based scripts converter workflow."""
import os
import re
import argparse
from importlib import import_module
from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
    BINARY_HEADER_PYTORCH_BITS
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport, UnknownModel
from mindinsight.utils.exceptions import ParamMissError

permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
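# Note: `permissions` above is rwx (0o7), so the resulting umask is 0o77 and
# files or directories created by this process should be readable, writable and
# executable by the owner only; group and others get no access. This reading is
# inferred from the bit arithmetic above rather than from separate documentation.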
parser = argparse.ArgumentParser(
    prog="MindConverter",
    description="Graph based MindConverter CLI entry point (version: {})".format(
        mindinsight.__version__)
)
parser.add_argument("--graph", type=str, required=True,
                    help="Third party framework's graph path.")
parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
                    help="Input shape of the model.")
parser.add_argument("--ckpt", type=str, required=False,
                    help="Third party framework's checkpoint path.")
parser.add_argument("--output", type=str, required=True,
                    help="Generated scripts output folder path.")
parser.add_argument("--report", type=str, required=False,
                    help="Generated reports output folder path.")
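# A hypothetical command-line invocation of the parser defined above, assuming
# the package exposes this entry point as the `mindconverter` console script
# (the paths and shape below are illustrative only):
#
#   mindconverter --graph /path/to/model.pb \
#                 --sample_shape 1 3 224 224 \
#                 --output /path/to/output \
#                 --report /path/to/report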
def torch_installation_validation(func):
    """
    Validate that PyTorch is installed before running func.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None):
        # Check whether PyTorch is installed.
        if not find_spec("torch"):
            error = ModuleNotFoundError("PyTorch is required when using graph based "
                                        "scripts converter, and the PyTorch version must "
                                        "be consistent with the model generation runtime.")
            log.error(str(error))
            log.exception(error)
            raise error

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder)

    return _f
def tf_installation_validation(func):
    """
    Validate that TensorFlow and tf2onnx are installed before running func.

    Args:
        func (type): Function.

    Returns:
        type, inner function.
    """

    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None,
           input_nodes: str = None, output_nodes: str = None):
        # Check whether TensorFlow and tf2onnx are installed.
        if not find_spec("tensorflow") or not find_spec("tf2onnx"):
            error = ModuleNotFoundError("TensorFlow and tf2onnx are required when using "
                                        "graph based scripts converter.")
            log.error(str(error))
            raise error

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder,
             input_nodes=input_nodes, output_nodes=output_nodes)

    return _f
def _extract_model_name(model_path):
    """
    Extract the model name from the model path.

    Args:
        model_path (str): Path of the converted model.

    Returns:
        str, name of the converted model.
    """
    model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1]
    return model_name
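# For example (hypothetical path), _extract_model_name('/tmp/resnet50.pth')
# returns "resnet50": the regular expression keeps the last path component and
# drops a trailing ".pth" or ".pb" suffix.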
@torch_installation_validation
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
                                        output_folder: str, report_folder: str = None):
    """
    PyTorch to MindSpore conversion based on graph.

    Args:
        graph_path (str): Graph file path.
        sample_shape (tuple): Input shape of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    third_party_graph_module = import_module(
        'mindinsight.mindconverter.graph_based_converter.third_party_graph')
    hierarchical_tree_module = import_module(
        'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
    cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
    cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')

    graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape)

    try:
        hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)
    except Exception as e:
        log.exception(e)
        log.error("Error occurred when creating the hierarchical tree.")
        raise NodeTypeNotSupport("This model is not supported now.")

    model_name = _extract_model_name(graph_path)

    hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                                        model_name=model_name,
                                        report_folder=report_folder)
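# A hypothetical direct call of the function above (requires PyTorch; the paths
# and shape are illustrative only):
#
#   graph_based_converter_pytorch_to_ms(graph_path="/path/to/model.pth",
#                                       sample_shape=(1, 3, 224, 224),
#                                       output_folder="/path/to/output",
#                                       report_folder="/path/to/report")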
@tf_installation_validation
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
                                   input_nodes: str, output_nodes: str,
                                   output_folder: str, report_folder: str = None):
    """
    TensorFlow to MindSpore conversion based on graph.

    Args:
        graph_path (str): Graph file path.
        sample_shape (tuple): Input shape of the model.
        input_nodes (str): Input node(s) of the model.
        output_nodes (str): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    third_party_graph_module = import_module(
        'mindinsight.mindconverter.graph_based_converter.third_party_graph')
    hierarchical_tree_module = import_module(
        'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
    cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
    cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')

    # Suppress unnecessary TensorFlow logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
                                       input_nodes=input_nodes, output_nodes=output_nodes)

    try:
        hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj)
    except Exception as e:
        log.exception(e)
        log.error("Error occurred when creating the hierarchical tree.")
        raise NodeTypeNotSupport("This model is not supported now.")

    model_name = _extract_model_name(graph_path)

    hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                                        model_name=model_name,
                                        report_folder=report_folder,
                                        scope_name_map=scope_name_map)
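# A hypothetical direct call of the TensorFlow variant (requires TensorFlow and
# tf2onnx; the node names, paths and shape are illustrative only):
#
#   graph_based_converter_tf_to_ms(graph_path="/path/to/model.pb",
#                                  sample_shape=(1, 224, 224, 3),
#                                  input_nodes="input:0",
#                                  output_nodes="output:0",
#                                  output_folder="/path/to/output",
#                                  report_folder="/path/to/report")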
def main_graph_base_converter(file_config):
    """
    Entry point of the graph based converter; the given model file is converted
    into MindSpore scripts.

    Args:
        file_config (dict): The configuration of the file to convert.
    """
    graph_path = file_config['model_file']
    frame_type = get_framework_type(graph_path)
    if frame_type == FrameworkType.PYTORCH.value:
        graph_based_converter_pytorch_to_ms(graph_path=graph_path,
                                            sample_shape=file_config['shape'],
                                            output_folder=file_config['outfile_dir'],
                                            report_folder=file_config['report_dir'])
    elif frame_type == FrameworkType.TENSORFLOW.value:
        check_params = ['input_nodes', 'output_nodes']
        check_params_exist(check_params, file_config)
        graph_based_converter_tf_to_ms(graph_path=graph_path,
                                       sample_shape=file_config['shape'],
                                       input_nodes=file_config['input_nodes'],
                                       output_nodes=file_config['output_nodes'],
                                       output_folder=file_config['outfile_dir'],
                                       report_folder=file_config['report_dir'])
    else:
        error_msg = "Get UNSUPPORTED model."
        error = UnknownModel(error_msg)
        log.error(str(error))
        log.exception(error)
        raise error
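# The keys read from `file_config` above suggest a structure like the following
# (the values are illustrative only):
#
#   file_config = {
#       'model_file': '/path/to/model.pb',
#       'shape': [1, 224, 224, 3],
#       'input_nodes': 'input:0',      # used for TensorFlow models only
#       'output_nodes': 'output:0',    # used for TensorFlow models only
#       'outfile_dir': '/path/to/output',
#       'report_dir': '/path/to/report',
#   }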
def get_framework_type(model_path):
    """Get framework type by checking the binary header of the model file."""
    try:
        with open(model_path, 'rb') as f:
            if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
                framework_type = FrameworkType.PYTORCH.value
            else:
                framework_type = FrameworkType.TENSORFLOW.value
    except IOError:
        error_msg = "Get UNSUPPORTED model."
        error = UnknownModel(error_msg)
        log.error(str(error))
        log.exception(error)
        raise error

    return framework_type
def check_params_exist(params: list, config):
    """Check params exist."""
    miss_param_list = ''
    for param in params:
        if not config.get(param) or not config[param]:
            miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

    if miss_param_list:
        raise ParamMissError(miss_param_list)