You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cli.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type
  22. from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER, \
  23. FrameworkType
  24. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  25. from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
  26. class ArgsCheck:
  27. """Args check."""
  28. @staticmethod
  29. def check_repeated(namespace, dest, default, option_string, parser_in):
  30. """Check repeated."""
  31. if getattr(namespace, dest, default) is not default:
  32. parser_in.error(f'Parameter `{option_string}` is set repeatedly.')
  33. class FileDirAction(argparse.Action):
  34. """File directory action class definition."""
  35. @staticmethod
  36. def check_path(parser_in, values, option_string=None):
  37. """
  38. Check argument for file path.
  39. Args:
  40. parser_in (ArgumentParser): Passed-in argument parser.
  41. values (object): Argument values with type depending on argument definition.
  42. option_string (str): Optional string for specific argument name. Default: None.
  43. """
  44. outfile = values
  45. if len(outfile) > ARGUMENT_LENGTH_LIMIT:
  46. parser_in.error(
  47. f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  48. if outfile.startswith('~'):
  49. outfile = os.path.realpath(os.path.expanduser(outfile))
  50. if not outfile.startswith('/'):
  51. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  52. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  53. parser_in.error(f'{option_string} {outfile} not accessible')
  54. return outfile
  55. def __call__(self, parser_in, namespace, values, option_string=None):
  56. """
  57. Inherited __call__ method from argparse.Action.
  58. Args:
  59. parser_in (ArgumentParser): Passed-in argument parser.
  60. namespace (Namespace): Namespace object to hold arguments.
  61. values (object): Argument values with type depending on argument definition.
  62. option_string (str): Optional string for specific argument name. Default: None.
  63. """
  64. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  65. outfile_dir = self.check_path(parser_in, values, option_string)
  66. if os.path.isfile(outfile_dir):
  67. parser_in.error(f'{option_string} {outfile_dir} is a file')
  68. setattr(namespace, self.dest, outfile_dir)
  69. class OutputDirAction(argparse.Action):
  70. """File directory action class definition."""
  71. def __call__(self, parser_in, namespace, values, option_string=None):
  72. """
  73. Inherited __call__ method from argparse.Action.
  74. Args:
  75. parser_in (ArgumentParser): Passed-in argument parser.
  76. namespace (Namespace): Namespace object to hold arguments.
  77. values (object): Argument values with type depending on argument definition.
  78. option_string (str): Optional string for specific argument name. Default: None.
  79. """
  80. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  81. output = values
  82. if len(output) > ARGUMENT_LENGTH_LIMIT:
  83. parser_in.error(
  84. f"The length of {option_string}{output} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  85. if output.startswith('~'):
  86. output = os.path.realpath(os.path.expanduser(output))
  87. if not output.startswith('/'):
  88. output = os.path.realpath(os.path.join(os.getcwd(), output))
  89. if os.path.exists(output):
  90. if not os.access(output, os.R_OK):
  91. parser_in.error(f'{option_string} {output} not accessible')
  92. if os.path.isfile(output):
  93. parser_in.error(f'{option_string} {output} is a file')
  94. setattr(namespace, self.dest, output)
  95. class ProjectPathAction(argparse.Action):
  96. """Project directory action class definition."""
  97. def __call__(self, parser_in, namespace, values, option_string=None):
  98. """
  99. Inherited __call__ method from argparse.Action.
  100. Args:
  101. parser_in (ArgumentParser): Passed-in argument parser.
  102. namespace (Namespace): Namespace object to hold arguments.
  103. values (object): Argument values with type depending on argument definition.
  104. option_string (str): Optional string for specific argument name. Default: None.
  105. """
  106. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  107. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  108. if not os.path.exists(outfile_dir):
  109. parser_in.error(f'{option_string} {outfile_dir} not exists')
  110. if not os.path.isdir(outfile_dir):
  111. parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')
  112. setattr(namespace, self.dest, outfile_dir)
  113. class InFileAction(argparse.Action):
  114. """Input File action class definition."""
  115. def __call__(self, parser_in, namespace, values, option_string=None):
  116. """
  117. Inherited __call__ method from argparse.Action.
  118. Args:
  119. parser_in (ArgumentParser): Passed-in argument parser.
  120. namespace (Namespace): Namespace object to hold arguments.
  121. values (object): Argument values with type depending on argument definition.
  122. option_string (str): Optional string for specific argument name. Default: None.
  123. """
  124. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  125. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  126. if not os.path.exists(outfile_dir):
  127. parser_in.error(f'{option_string} {outfile_dir} not exists')
  128. if not os.path.isfile(outfile_dir):
  129. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  130. if not os.path.basename(outfile_dir).endswith("py"):
  131. parser_in.error(f'{option_string} {outfile_dir} is not a valid python file')
  132. setattr(namespace, self.dest, outfile_dir)
  133. class ModelFileAction(argparse.Action):
  134. """Model File action class definition."""
  135. def __call__(self, parser_in, namespace, values, option_string=None):
  136. """
  137. Inherited __call__ method from argparse.Action.
  138. Args:
  139. parser_in (ArgumentParser): Passed-in argument parser.
  140. namespace (Namespace): Namespace object to hold arguments.
  141. values (object): Argument values with type depending on argument definition.
  142. option_string (str): Optional string for specific argument name. Default: None.
  143. """
  144. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  145. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  146. if not os.path.exists(outfile_dir):
  147. parser_in.error(f'{option_string} {outfile_dir} not exists')
  148. if not os.path.isfile(outfile_dir):
  149. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  150. frame_type = get_framework_type(outfile_dir)
  151. if frame_type == FrameworkType.UNKNOWN.value:
  152. parser_in.error(f'{option_string} {outfile_dir} should be an valid '
  153. f'TensorFlow pb or PyTorch pth model file')
  154. setattr(namespace, self.dest, outfile_dir)
  155. class LogFileAction(argparse.Action):
  156. """Log file action class definition."""
  157. def __call__(self, parser_in, namespace, values, option_string=None):
  158. """
  159. Inherited __call__ method from FileDirAction.
  160. Args:
  161. parser_in (ArgumentParser): Passed-in argument parser.
  162. namespace (Namespace): Namespace object to hold arguments.
  163. values (object): Argument values with type depending on argument definition.
  164. option_string (str): Optional string for specific argument name. Default: None.
  165. """
  166. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  167. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  168. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  169. parser_in.error(f'{option_string} {outfile_dir} is not a directory')
  170. setattr(namespace, self.dest, outfile_dir)
  171. class ShapeAction(argparse.Action):
  172. """Shape action class definition."""
  173. def __call__(self, parser_in, namespace, values, option_string=None):
  174. """
  175. Inherited __call__ method from FileDirAction.
  176. Args:
  177. parser_in (ArgumentParser): Passed-in argument parser.
  178. namespace (Namespace): Namespace object to hold arguments.
  179. values (object): Argument values with type depending on argument definition.
  180. option_string (str): Optional string for specific argument name. Default: None.
  181. """
  182. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  183. in_shape = None
  184. shape_str = values
  185. shape_list = shape_str.split(':')
  186. if not len(shape_list) == EXPECTED_NUMBER:
  187. parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")
  188. try:
  189. in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')]
  190. except ValueError:
  191. parser_in.error(
  192. f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
  193. setattr(namespace, self.dest, in_shape)
  194. class NodeAction(argparse.Action):
  195. """Node action class definition."""
  196. def __call__(self, parser_in, namespace, values, option_string=None):
  197. """
  198. Inherited __call__ method from FileDirAction.
  199. Args:
  200. parser_in (ArgumentParser): Passed-in argument parser.
  201. namespace (Namespace): Namespace object to hold arguments.
  202. values (object): Argument values with type depending on argument definition.
  203. option_string (str): Optional string for specific argument name. Default: None.
  204. """
  205. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  206. node_str = values
  207. if len(node_str) > ARGUMENT_LENGTH_LIMIT:
  208. parser_in.error(
  209. f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
  210. )
  211. node_list = node_str.split(',')
  212. if not len(node_list) == EXPECTED_NUMBER:
  213. parser_in.error(f"Only support one {option_string} now, but get {len(node_list)}.")
  214. setattr(namespace, self.dest, node_str)
  215. parser = argparse.ArgumentParser(
  216. prog='mindconverter',
  217. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
  218. allow_abbrev=False)
  219. parser.add_argument(
  220. '--version',
  221. action='version',
  222. version='%(prog)s ({})'.format(mindinsight.__version__))
  223. parser.add_argument(
  224. '--in_file',
  225. type=str,
  226. action=InFileAction,
  227. required=False,
  228. default=None,
  229. help="""
  230. Specify path for script file to use AST schema to
  231. do script conversation.
  232. """)
  233. parser.add_argument(
  234. '--model_file',
  235. type=str,
  236. action=ModelFileAction,
  237. required=False,
  238. help="""
  239. PyTorch(.pth), Tensorflow(.pb) or ONNX(.onnx) model file path
  240. is expected to do script generation based on graph schema. When
  241. `--in_file` and `--model_file` are both provided,
  242. use AST schema as default.
  243. """)
  244. parser.add_argument(
  245. '--shape',
  246. type=str,
  247. action=ShapeAction,
  248. default=None,
  249. required=False,
  250. help="""
  251. Optional, expected input tensor shape of
  252. `--model_file`. It's required when use graph based
  253. schema.
  254. Usage: --shape 1,3,244,244
  255. """)
  256. parser.add_argument(
  257. '--input_nodes',
  258. type=str,
  259. action=NodeAction,
  260. default=None,
  261. required=False,
  262. help="""
  263. Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
  264. Usage: --input_nodes input_1:0,input_2:0
  265. """)
  266. parser.add_argument(
  267. '--output_nodes',
  268. type=str,
  269. action=NodeAction,
  270. default=None,
  271. required=False,
  272. help="""
  273. Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
  274. Usage: --output_nodes output_1:0,output_2:0
  275. """)
  276. parser.add_argument(
  277. '--output',
  278. type=str,
  279. action=OutputDirAction,
  280. default=os.path.join(os.getcwd(), 'output'),
  281. help="""
  282. Optional, specify path for converted script file
  283. directory. Default output directory is `output` folder
  284. in the current working directory.
  285. """)
  286. parser.add_argument(
  287. '--report',
  288. type=str,
  289. action=LogFileAction,
  290. default=None,
  291. help="""
  292. Optional, specify report directory. Default is
  293. converted script directory.
  294. """)
  295. parser.add_argument(
  296. '--project_path',
  297. type=str,
  298. action=ProjectPathAction,
  299. required=False,
  300. default=None,
  301. help="""
  302. Optional, PyTorch scripts project path. If PyTorch
  303. project is not in PYTHONPATH, please assign
  304. `--project_path` when use graph based schema.
  305. Usage: --project_path ~/script_file/
  306. """)
  307. def cli_entry():
  308. """Entry point for mindconverter CLI."""
  309. permissions = os.R_OK | os.W_OK | os.X_OK
  310. os.umask(permissions << 3 | permissions)
  311. argv = sys.argv[1:]
  312. if not argv:
  313. argv = ['-h']
  314. args = parser.parse_args(argv)
  315. else:
  316. args = parser.parse_args()
  317. mode = permissions << 6
  318. os.makedirs(args.output, mode=mode, exist_ok=True)
  319. if args.report is None:
  320. args.report = args.output
  321. os.makedirs(args.report, mode=mode, exist_ok=True)
  322. _run(args.in_file, args.model_file,
  323. args.shape,
  324. args.input_nodes, args.output_nodes,
  325. args.output, args.report,
  326. args.project_path)
  327. def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
  328. """
  329. Run converter command.
  330. Args:
  331. in_files (str): The file path or directory to convert.
  332. model_file(str): The pytorch .pth to convert on graph based schema.
  333. shape(list): The input tensor shape of module_file.
  334. input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
  335. output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
  336. out_dir (str): The output directory to save converted file.
  337. report (str): The report file path.
  338. project_path(str): Pytorch scripts project path.
  339. """
  340. if in_files:
  341. files_config = {
  342. 'root_path': in_files,
  343. 'in_files': [],
  344. 'outfile_dir': out_dir,
  345. 'report_dir': report if report else out_dir
  346. }
  347. if os.path.isfile(in_files):
  348. files_config['root_path'] = os.path.dirname(in_files)
  349. files_config['in_files'] = [in_files]
  350. else:
  351. for root_dir, _, files in os.walk(in_files):
  352. for file in files:
  353. files_config['in_files'].append(os.path.join(root_dir, file))
  354. main(files_config)
  355. log_console.info("\n")
  356. log_console.info("MindConverter: conversion is completed.")
  357. log_console.info("\n")
  358. elif model_file:
  359. file_config = {
  360. 'model_file': model_file,
  361. 'shape': shape if shape else [],
  362. 'input_nodes': input_nodes,
  363. 'output_nodes': output_nodes,
  364. 'outfile_dir': out_dir,
  365. 'report_dir': report if report else out_dir
  366. }
  367. if project_path:
  368. paths = sys.path
  369. if project_path not in paths:
  370. sys.path.append(project_path)
  371. main_graph_base_converter(file_config)
  372. log_console.info("\n")
  373. log_console.info("MindConverter: conversion is completed.")
  374. log_console.info("\n")
  375. else:
  376. error_msg = "`--in_file` and `--model_file` should be set at least one."
  377. error = FileNotFoundError(error_msg)
  378. log.error(str(error))
  379. log.exception(error)
  380. raise error