You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cli.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER
  22. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  23. from mindinsight.mindconverter.common.log import logger as log
  24. class ArgsCheck:
  25. """Args check."""
  26. @staticmethod
  27. def check_repeated(namespace, dest, default, option_string, parser_in):
  28. """Check repeated."""
  29. if getattr(namespace, dest, default) is not default:
  30. parser_in.error(f'Parameter `{option_string}` is set repeatedly.')
  31. class FileDirAction(argparse.Action):
  32. """File directory action class definition."""
  33. @staticmethod
  34. def check_path(parser_in, values, option_string=None):
  35. """
  36. Check argument for file path.
  37. Args:
  38. parser_in (ArgumentParser): Passed-in argument parser.
  39. values (object): Argument values with type depending on argument definition.
  40. option_string (str): Optional string for specific argument name. Default: None.
  41. """
  42. outfile = values
  43. if len(outfile) > ARGUMENT_LENGTH_LIMIT:
  44. parser_in.error(
  45. f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  46. if outfile.startswith('~'):
  47. outfile = os.path.realpath(os.path.expanduser(outfile))
  48. if not outfile.startswith('/'):
  49. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  50. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  51. parser_in.error(f'{option_string} {outfile} not accessible')
  52. return outfile
  53. def __call__(self, parser_in, namespace, values, option_string=None):
  54. """
  55. Inherited __call__ method from argparse.Action.
  56. Args:
  57. parser_in (ArgumentParser): Passed-in argument parser.
  58. namespace (Namespace): Namespace object to hold arguments.
  59. values (object): Argument values with type depending on argument definition.
  60. option_string (str): Optional string for specific argument name. Default: None.
  61. """
  62. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  63. outfile_dir = self.check_path(parser_in, values, option_string)
  64. if os.path.isfile(outfile_dir):
  65. parser_in.error(f'{option_string} {outfile_dir} is a file')
  66. setattr(namespace, self.dest, outfile_dir)
  67. class OutputDirAction(argparse.Action):
  68. """File directory action class definition."""
  69. def __call__(self, parser_in, namespace, values, option_string=None):
  70. """
  71. Inherited __call__ method from argparse.Action.
  72. Args:
  73. parser_in (ArgumentParser): Passed-in argument parser.
  74. namespace (Namespace): Namespace object to hold arguments.
  75. values (object): Argument values with type depending on argument definition.
  76. option_string (str): Optional string for specific argument name. Default: None.
  77. """
  78. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  79. output = values
  80. if len(output) > ARGUMENT_LENGTH_LIMIT:
  81. parser_in.error(
  82. f"The length of {option_string}{output} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  83. if output.startswith('~'):
  84. output = os.path.realpath(os.path.expanduser(output))
  85. if not output.startswith('/'):
  86. output = os.path.realpath(os.path.join(os.getcwd(), output))
  87. if os.path.exists(output):
  88. if not os.access(output, os.R_OK):
  89. parser_in.error(f'{option_string} {output} not accessible')
  90. if os.path.isfile(output):
  91. parser_in.error(f'{option_string} {output} is a file')
  92. setattr(namespace, self.dest, output)
  93. class ProjectPathAction(argparse.Action):
  94. """Project directory action class definition."""
  95. def __call__(self, parser_in, namespace, values, option_string=None):
  96. """
  97. Inherited __call__ method from argparse.Action.
  98. Args:
  99. parser_in (ArgumentParser): Passed-in argument parser.
  100. namespace (Namespace): Namespace object to hold arguments.
  101. values (object): Argument values with type depending on argument definition.
  102. option_string (str): Optional string for specific argument name. Default: None.
  103. """
  104. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  105. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  106. if not os.path.exists(outfile_dir):
  107. parser_in.error(f'{option_string} {outfile_dir} not exists')
  108. if not os.path.isdir(outfile_dir):
  109. parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')
  110. setattr(namespace, self.dest, outfile_dir)
  111. class InFileAction(argparse.Action):
  112. """Input File action class definition."""
  113. def __call__(self, parser_in, namespace, values, option_string=None):
  114. """
  115. Inherited __call__ method from argparse.Action.
  116. Args:
  117. parser_in (ArgumentParser): Passed-in argument parser.
  118. namespace (Namespace): Namespace object to hold arguments.
  119. values (object): Argument values with type depending on argument definition.
  120. option_string (str): Optional string for specific argument name. Default: None.
  121. """
  122. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  123. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  124. if not os.path.exists(outfile_dir):
  125. parser_in.error(f'{option_string} {outfile_dir} not exists')
  126. if not os.path.isfile(outfile_dir):
  127. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  128. setattr(namespace, self.dest, outfile_dir)
  129. class ModelFileAction(argparse.Action):
  130. """Model File action class definition."""
  131. def __call__(self, parser_in, namespace, values, option_string=None):
  132. """
  133. Inherited __call__ method from argparse.Action.
  134. Args:
  135. parser_in (ArgumentParser): Passed-in argument parser.
  136. namespace (Namespace): Namespace object to hold arguments.
  137. values (object): Argument values with type depending on argument definition.
  138. option_string (str): Optional string for specific argument name. Default: None.
  139. """
  140. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  141. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  142. if not os.path.exists(outfile_dir):
  143. parser_in.error(f'{option_string} {outfile_dir} not exists')
  144. if not os.path.isfile(outfile_dir):
  145. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  146. setattr(namespace, self.dest, outfile_dir)
  147. class LogFileAction(argparse.Action):
  148. """Log file action class definition."""
  149. def __call__(self, parser_in, namespace, values, option_string=None):
  150. """
  151. Inherited __call__ method from FileDirAction.
  152. Args:
  153. parser_in (ArgumentParser): Passed-in argument parser.
  154. namespace (Namespace): Namespace object to hold arguments.
  155. values (object): Argument values with type depending on argument definition.
  156. option_string (str): Optional string for specific argument name. Default: None.
  157. """
  158. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  159. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  160. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  161. parser_in.error(f'{option_string} {outfile_dir} is not a directory')
  162. setattr(namespace, self.dest, outfile_dir)
  163. class ShapeAction(argparse.Action):
  164. """Shape action class definition."""
  165. def __call__(self, parser_in, namespace, values, option_string=None):
  166. """
  167. Inherited __call__ method from FileDirAction.
  168. Args:
  169. parser_in (ArgumentParser): Passed-in argument parser.
  170. namespace (Namespace): Namespace object to hold arguments.
  171. values (object): Argument values with type depending on argument definition.
  172. option_string (str): Optional string for specific argument name. Default: None.
  173. """
  174. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  175. in_shape = None
  176. shape_str = values
  177. shape_list = shape_str.split(':')
  178. if not len(shape_list) == EXPECTED_NUMBER:
  179. parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")
  180. try:
  181. in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')]
  182. except ValueError:
  183. parser_in.error(
  184. f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
  185. setattr(namespace, self.dest, in_shape)
  186. class NodeAction(argparse.Action):
  187. """Node action class definition."""
  188. def __call__(self, parser_in, namespace, values, option_string=None):
  189. """
  190. Inherited __call__ method from FileDirAction.
  191. Args:
  192. parser_in (ArgumentParser): Passed-in argument parser.
  193. namespace (Namespace): Namespace object to hold arguments.
  194. values (object): Argument values with type depending on argument definition.
  195. option_string (str): Optional string for specific argument name. Default: None.
  196. """
  197. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  198. node_str = values
  199. if len(node_str) > ARGUMENT_LENGTH_LIMIT:
  200. parser_in.error(
  201. f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
  202. )
  203. node_list = node_str.split(',')
  204. if not len(node_list) == EXPECTED_NUMBER:
  205. parser_in.error(f"Only support one {option_string} now, but get {len(node_list)}.")
  206. setattr(namespace, self.dest, node_str)
  207. parser = argparse.ArgumentParser(
  208. prog='mindconverter',
  209. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
  210. allow_abbrev=False)
  211. parser.add_argument(
  212. '--version',
  213. action='version',
  214. version='%(prog)s ({})'.format(mindinsight.__version__))
  215. parser.add_argument(
  216. '--in_file',
  217. type=str,
  218. action=InFileAction,
  219. required=False,
  220. default=None,
  221. help="""
  222. Specify path for script file to use AST schema to
  223. do script conversation.
  224. """)
  225. parser.add_argument(
  226. '--model_file',
  227. type=str,
  228. action=ModelFileAction,
  229. required=False,
  230. help="""
  231. PyTorch .pth or Tensorflow .pb model file path to use graph
  232. based schema to do script generation. When
  233. `--in_file` and `--model_file` are both provided,
  234. use AST schema as default.
  235. """)
  236. parser.add_argument(
  237. '--shape',
  238. type=str,
  239. action=ShapeAction,
  240. default=None,
  241. required=False,
  242. help="""
  243. Optional, expected input tensor shape of
  244. `--model_file`. It's required when use graph based
  245. schema.
  246. Usage: --shape 1,3,244,244
  247. """)
  248. parser.add_argument(
  249. '--input_nodes',
  250. type=str,
  251. action=NodeAction,
  252. default=None,
  253. required=False,
  254. help="""
  255. Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
  256. Usage: --input_nodes input_1:0,input_2:0
  257. """)
  258. parser.add_argument(
  259. '--output_nodes',
  260. type=str,
  261. action=NodeAction,
  262. default=None,
  263. required=False,
  264. help="""
  265. Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
  266. Usage: --output_nodes output_1:0,output_2:0
  267. """)
  268. parser.add_argument(
  269. '--output',
  270. type=str,
  271. action=OutputDirAction,
  272. default=os.path.join(os.getcwd(), 'output'),
  273. help="""
  274. Optional, specify path for converted script file
  275. directory. Default output directory is `output` folder
  276. in the current working directory.
  277. """)
  278. parser.add_argument(
  279. '--report',
  280. type=str,
  281. action=LogFileAction,
  282. default=None,
  283. help="""
  284. Optional, specify report directory. Default is
  285. converted script directory.
  286. """)
  287. parser.add_argument(
  288. '--project_path',
  289. type=str,
  290. action=ProjectPathAction,
  291. required=False,
  292. default=None,
  293. help="""
  294. Optional, PyTorch scripts project path. If PyTorch
  295. project is not in PYTHONPATH, please assign
  296. `--project_path` when use graph based schema.
  297. Usage: --project_path ~/script_file/
  298. """)
  299. def cli_entry():
  300. """Entry point for mindconverter CLI."""
  301. permissions = os.R_OK | os.W_OK | os.X_OK
  302. os.umask(permissions << 3 | permissions)
  303. argv = sys.argv[1:]
  304. if not argv:
  305. argv = ['-h']
  306. args = parser.parse_args(argv)
  307. else:
  308. args = parser.parse_args()
  309. mode = permissions << 6
  310. os.makedirs(args.output, mode=mode, exist_ok=True)
  311. if args.report is None:
  312. args.report = args.output
  313. os.makedirs(args.report, mode=mode, exist_ok=True)
  314. _run(args.in_file, args.model_file,
  315. args.shape,
  316. args.input_nodes, args.output_nodes,
  317. args.output, args.report,
  318. args.project_path)
  319. def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
  320. """
  321. Run converter command.
  322. Args:
  323. in_files (str): The file path or directory to convert.
  324. model_file(str): The pytorch .pth to convert on graph based schema.
  325. shape(list): The input tensor shape of module_file.
  326. input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
  327. output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
  328. out_dir (str): The output directory to save converted file.
  329. report (str): The report file path.
  330. project_path(str): Pytorch scripts project path.
  331. """
  332. if in_files:
  333. files_config = {
  334. 'root_path': in_files,
  335. 'in_files': [],
  336. 'outfile_dir': out_dir,
  337. 'report_dir': report if report else out_dir
  338. }
  339. if os.path.isfile(in_files):
  340. files_config['root_path'] = os.path.dirname(in_files)
  341. files_config['in_files'] = [in_files]
  342. else:
  343. for root_dir, _, files in os.walk(in_files):
  344. for file in files:
  345. files_config['in_files'].append(os.path.join(root_dir, file))
  346. main(files_config)
  347. elif model_file:
  348. file_config = {
  349. 'model_file': model_file,
  350. 'shape': shape if shape else [],
  351. 'input_nodes': input_nodes,
  352. 'output_nodes': output_nodes,
  353. 'outfile_dir': out_dir,
  354. 'report_dir': report if report else out_dir
  355. }
  356. if project_path:
  357. paths = sys.path
  358. if project_path not in paths:
  359. sys.path.append(project_path)
  360. main_graph_base_converter(file_config)
  361. else:
  362. error_msg = "`--in_file` and `--model_file` should be set at least one."
  363. error = FileNotFoundError(error_msg)
  364. log.error(str(error))
  365. log.exception(error)
  366. raise error