You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cli.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type
  22. from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, \
  23. ARGUMENT_NUM_LIMIT, ARGUMENT_LEN_LIMIT, FrameworkType
  24. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  25. from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
  26. class ArgsCheck:
  27. """Args check."""
  28. @staticmethod
  29. def check_repeated(namespace, dest, default, option_string, parser_in):
  30. """Check repeated."""
  31. if getattr(namespace, dest, default) is not default:
  32. parser_in.error(f'Parameter `{option_string}` is set repeatedly.')
  33. class FileDirAction(argparse.Action):
  34. """File directory action class definition."""
  35. @staticmethod
  36. def check_path(parser_in, values, option_string=None):
  37. """
  38. Check argument for file path.
  39. Args:
  40. parser_in (ArgumentParser): Passed-in argument parser.
  41. values (object): Argument values with type depending on argument definition.
  42. option_string (str): Optional string for specific argument name. Default: None.
  43. """
  44. outfile = values
  45. if len(outfile) > ARGUMENT_LENGTH_LIMIT:
  46. parser_in.error(
  47. f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  48. if outfile.startswith('~'):
  49. outfile = os.path.realpath(os.path.expanduser(outfile))
  50. if not outfile.startswith('/'):
  51. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  52. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  53. parser_in.error(f'{option_string} {outfile} not accessible')
  54. return outfile
  55. def __call__(self, parser_in, namespace, values, option_string=None):
  56. """
  57. Inherited __call__ method from argparse.Action.
  58. Args:
  59. parser_in (ArgumentParser): Passed-in argument parser.
  60. namespace (Namespace): Namespace object to hold arguments.
  61. values (object): Argument values with type depending on argument definition.
  62. option_string (str): Optional string for specific argument name. Default: None.
  63. """
  64. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  65. outfile_dir = self.check_path(parser_in, values, option_string)
  66. if os.path.isfile(outfile_dir):
  67. parser_in.error(f'{option_string} {outfile_dir} is a file')
  68. setattr(namespace, self.dest, outfile_dir)
  69. class OutputDirAction(argparse.Action):
  70. """File directory action class definition."""
  71. def __call__(self, parser_in, namespace, values, option_string=None):
  72. """
  73. Inherited __call__ method from argparse.Action.
  74. Args:
  75. parser_in (ArgumentParser): Passed-in argument parser.
  76. namespace (Namespace): Namespace object to hold arguments.
  77. values (object): Argument values with type depending on argument definition.
  78. option_string (str): Optional string for specific argument name. Default: None.
  79. """
  80. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  81. output = values
  82. if len(output) > ARGUMENT_LENGTH_LIMIT:
  83. parser_in.error(
  84. f"The length of {option_string}{output} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  85. if output.startswith('~'):
  86. output = os.path.realpath(os.path.expanduser(output))
  87. if not output.startswith('/'):
  88. output = os.path.realpath(os.path.join(os.getcwd(), output))
  89. if os.path.exists(output):
  90. if not os.access(output, os.R_OK):
  91. parser_in.error(f'{option_string} {output} not accessible')
  92. if os.path.isfile(output):
  93. parser_in.error(f'{option_string} {output} is a file')
  94. setattr(namespace, self.dest, output)
  95. class ProjectPathAction(argparse.Action):
  96. """Project directory action class definition."""
  97. def __call__(self, parser_in, namespace, values, option_string=None):
  98. """
  99. Inherited __call__ method from argparse.Action.
  100. Args:
  101. parser_in (ArgumentParser): Passed-in argument parser.
  102. namespace (Namespace): Namespace object to hold arguments.
  103. values (object): Argument values with type depending on argument definition.
  104. option_string (str): Optional string for specific argument name. Default: None.
  105. """
  106. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  107. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  108. if not os.path.exists(outfile_dir):
  109. parser_in.error(f'{option_string} {outfile_dir} not exists')
  110. if not os.path.isdir(outfile_dir):
  111. parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')
  112. setattr(namespace, self.dest, outfile_dir)
  113. class InFileAction(argparse.Action):
  114. """Input File action class definition."""
  115. def __call__(self, parser_in, namespace, values, option_string=None):
  116. """
  117. Inherited __call__ method from argparse.Action.
  118. Args:
  119. parser_in (ArgumentParser): Passed-in argument parser.
  120. namespace (Namespace): Namespace object to hold arguments.
  121. values (object): Argument values with type depending on argument definition.
  122. option_string (str): Optional string for specific argument name. Default: None.
  123. """
  124. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  125. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  126. if not os.path.exists(outfile_dir):
  127. parser_in.error(f'{option_string} {outfile_dir} not exists')
  128. if not os.path.isfile(outfile_dir):
  129. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  130. if not os.path.basename(outfile_dir).endswith("py"):
  131. parser_in.error(f'{option_string} {outfile_dir} is not a valid python file')
  132. setattr(namespace, self.dest, outfile_dir)
  133. class ModelFileAction(argparse.Action):
  134. """Model File action class definition."""
  135. def __call__(self, parser_in, namespace, values, option_string=None):
  136. """
  137. Inherited __call__ method from argparse.Action.
  138. Args:
  139. parser_in (ArgumentParser): Passed-in argument parser.
  140. namespace (Namespace): Namespace object to hold arguments.
  141. values (object): Argument values with type depending on argument definition.
  142. option_string (str): Optional string for specific argument name. Default: None.
  143. """
  144. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  145. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  146. if not os.path.exists(outfile_dir):
  147. parser_in.error(f'{option_string} {outfile_dir} not exists')
  148. if not os.path.isfile(outfile_dir):
  149. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  150. frame_type = get_framework_type(outfile_dir)
  151. if frame_type == FrameworkType.UNKNOWN.value:
  152. parser_in.error(f'{option_string} {outfile_dir} should be an valid '
  153. f'TensorFlow pb or PyTorch pth model file')
  154. setattr(namespace, self.dest, outfile_dir)
  155. class LogFileAction(argparse.Action):
  156. """Log file action class definition."""
  157. def __call__(self, parser_in, namespace, values, option_string=None):
  158. """
  159. Inherited __call__ method from FileDirAction.
  160. Args:
  161. parser_in (ArgumentParser): Passed-in argument parser.
  162. namespace (Namespace): Namespace object to hold arguments.
  163. values (object): Argument values with type depending on argument definition.
  164. option_string (str): Optional string for specific argument name. Default: None.
  165. """
  166. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  167. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  168. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  169. parser_in.error(f'{option_string} {outfile_dir} is not a directory')
  170. setattr(namespace, self.dest, outfile_dir)
  171. class ShapeAction(argparse.Action):
  172. """Shape action class definition."""
  173. def __call__(self, parser_in, namespace, values, option_string=None):
  174. """
  175. Inherited __call__ method from FileDirAction.
  176. Args:
  177. parser_in (ArgumentParser): Passed-in argument parser.
  178. namespace (Namespace): Namespace object to hold arguments.
  179. values (list): Argument values with type depending on argument definition.
  180. option_string (str): Optional string for specific argument name. Default: None.
  181. """
  182. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  183. def _convert_to_int(shape_list):
  184. return [int(num_shape) for num_shape in shape_list.split(',')]
  185. try:
  186. if len(values) > ARGUMENT_NUM_LIMIT:
  187. parser_in.error(f"The length of {option_string} {values} should be no more than {ARGUMENT_NUM_LIMIT}.")
  188. in_shape = []
  189. for v in values:
  190. shape = _convert_to_int(v)
  191. if len(shape) > ARGUMENT_LEN_LIMIT:
  192. parser_in.error(
  193. f"The length of {option_string} {shape} should be no more than {ARGUMENT_LEN_LIMIT}.")
  194. in_shape.append(shape)
  195. setattr(namespace, self.dest, in_shape)
  196. except ValueError:
  197. parser_in.error(
  198. f"{option_string} {values} should be list of integers split by ',', check it please.")
  199. class NodeAction(argparse.Action):
  200. """Node action class definition."""
  201. def __call__(self, parser_in, namespace, values, option_string=None):
  202. """
  203. Inherited __call__ method from FileDirAction.
  204. Args:
  205. parser_in (ArgumentParser): Passed-in argument parser.
  206. namespace (Namespace): Namespace object to hold arguments.
  207. values (list): Argument values with type depending on argument definition.
  208. option_string (str): Optional string for specific argument name. Default: None.
  209. """
  210. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  211. if len(values) > ARGUMENT_NUM_LIMIT:
  212. parser_in.error(f"The length of {option_string} {values} should be no more than {ARGUMENT_NUM_LIMIT}.")
  213. deduplicated = set()
  214. abnormal_nodes = []
  215. for v in values:
  216. if len(v) > ARGUMENT_LENGTH_LIMIT:
  217. parser_in.error(
  218. f"The length of {option_string} {v} should be no more than {ARGUMENT_LENGTH_LIMIT}."
  219. )
  220. if v in deduplicated:
  221. abnormal_nodes.append(v)
  222. continue
  223. deduplicated.add(v)
  224. if abnormal_nodes:
  225. parser_in.error(f"{', '.join(abnormal_nodes)} {'is' if len(abnormal_nodes) == 1 else 'are'} duplicated.")
  226. setattr(namespace, self.dest, values)
  227. parser = argparse.ArgumentParser(
  228. prog='mindconverter',
  229. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
  230. allow_abbrev=False)
  231. parser.add_argument(
  232. '--version',
  233. action='version',
  234. version='%(prog)s ({})'.format(mindinsight.__version__))
  235. parser.add_argument(
  236. '--in_file',
  237. type=str,
  238. action=InFileAction,
  239. required=False,
  240. default=None,
  241. help="""
  242. Specify path for script file to use AST schema to
  243. do script conversation.
  244. """)
  245. parser.add_argument(
  246. '--model_file',
  247. type=str,
  248. action=ModelFileAction,
  249. required=False,
  250. help="""
  251. PyTorch(.pth), Tensorflow(.pb) or ONNX(.onnx) model file path
  252. is expected to do script generation based on graph schema. When
  253. `--in_file` and `--model_file` are both provided,
  254. use AST schema as default.
  255. """)
  256. parser.add_argument(
  257. '--shape',
  258. type=str,
  259. action=ShapeAction,
  260. default=None,
  261. required=False,
  262. nargs="+",
  263. help="""
  264. Optional, expected input tensor shape of
  265. `--model_file`. It is required when use graph based
  266. schema. Both order and number should be consistent with `--input_nodes`.
  267. Usage: --shape 1,512 1,512
  268. """)
  269. parser.add_argument(
  270. '--input_nodes',
  271. type=str,
  272. action=NodeAction,
  273. default=None,
  274. required=False,
  275. nargs="+",
  276. help="""
  277. Optional, input node(s) name of `--model_file`. It is required when use TensorFlow and ONNX model.
  278. Both order and number should be consistent with `--shape`. Usage: --input_nodes input_1:0 input_2:0
  279. """)
  280. parser.add_argument(
  281. '--output_nodes',
  282. type=str,
  283. action=NodeAction,
  284. default=None,
  285. required=False,
  286. nargs="+",
  287. help="""
  288. Optional, output node(s) name of `--model_file`. It is required when use TensorFlow and ONNX model.
  289. Usage: --output_nodes output_1:0 output_2:0
  290. """)
  291. parser.add_argument(
  292. '--output',
  293. type=str,
  294. action=OutputDirAction,
  295. default=os.path.join(os.getcwd(), 'output'),
  296. help="""
  297. Optional, specify path for converted script file
  298. directory. Default output directory is `output` folder
  299. in the current working directory.
  300. """)
  301. parser.add_argument(
  302. '--report',
  303. type=str,
  304. action=LogFileAction,
  305. default=None,
  306. help="""
  307. Optional, specify report directory. Default is
  308. converted script directory.
  309. """)
  310. parser.add_argument(
  311. '--project_path',
  312. type=str,
  313. action=ProjectPathAction,
  314. required=False,
  315. default=None,
  316. help="""
  317. Optional, PyTorch scripts project path. If PyTorch
  318. project is not in PYTHONPATH, please assign
  319. `--project_path` when use graph based schema.
  320. Usage: --project_path ~/script_file/
  321. """)
  322. def cli_entry():
  323. """Entry point for mindconverter CLI."""
  324. permissions = os.R_OK | os.W_OK | os.X_OK
  325. os.umask(permissions << 3 | permissions)
  326. argv = sys.argv[1:]
  327. if not argv:
  328. argv = ['-h']
  329. args = parser.parse_args(argv)
  330. else:
  331. args = parser.parse_args()
  332. mode = permissions << 6
  333. os.makedirs(args.output, mode=mode, exist_ok=True)
  334. if args.report is None:
  335. args.report = args.output
  336. os.makedirs(args.report, mode=mode, exist_ok=True)
  337. _run(args.in_file, args.model_file,
  338. args.shape,
  339. args.input_nodes, args.output_nodes,
  340. args.output, args.report,
  341. args.project_path)
  342. def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
  343. """
  344. Run converter command.
  345. Args:
  346. in_files (str): The file path or directory to convert.
  347. model_file(str): The pytorch .pth to convert on graph based schema.
  348. shape(list): The input tensor shape of module_file.
  349. input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
  350. output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
  351. out_dir (str): The output directory to save converted file.
  352. report (str): The report file path.
  353. project_path(str): Pytorch scripts project path.
  354. """
  355. if in_files:
  356. files_config = {
  357. 'root_path': in_files,
  358. 'in_files': [],
  359. 'outfile_dir': out_dir,
  360. 'report_dir': report if report else out_dir
  361. }
  362. if os.path.isfile(in_files):
  363. files_config['root_path'] = os.path.dirname(in_files)
  364. files_config['in_files'] = [in_files]
  365. else:
  366. for root_dir, _, files in os.walk(in_files):
  367. for file in files:
  368. files_config['in_files'].append(os.path.join(root_dir, file))
  369. main(files_config)
  370. log_console.info("MindConverter: conversion is completed.")
  371. elif model_file:
  372. file_config = {
  373. 'model_file': model_file,
  374. 'shape': shape if shape else [],
  375. 'input_nodes': input_nodes,
  376. 'output_nodes': output_nodes,
  377. 'outfile_dir': out_dir,
  378. 'report_dir': report if report else out_dir
  379. }
  380. if project_path:
  381. paths = sys.path
  382. if project_path not in paths:
  383. sys.path.append(project_path)
  384. main_graph_base_converter(file_config)
  385. log_console.info("MindConverter: conversion is completed.")
  386. else:
  387. error_msg = "`--in_file` and `--model_file` should be set at least one."
  388. error = FileNotFoundError(error_msg)
  389. log.error(str(error))
  390. log_console.error(f"mindconverter: error: {str(error)}")
  391. sys.exit(-1)