You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cli.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER
  22. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  23. from mindinsight.mindconverter.common.log import logger as log
  24. class ArgsCheck:
  25. """Args check."""
  26. @staticmethod
  27. def check_repeated(namespace, dest, default, option_string, parser_in):
  28. """Check repeated."""
  29. if getattr(namespace, dest, default) is not default:
  30. parser_in.error(f'Parameter `{option_string}` is set repeatedly.')
  31. class FileDirAction(argparse.Action):
  32. """File directory action class definition."""
  33. @staticmethod
  34. def check_path(parser_in, values, option_string=None):
  35. """
  36. Check argument for file path.
  37. Args:
  38. parser_in (ArgumentParser): Passed-in argument parser.
  39. values (object): Argument values with type depending on argument definition.
  40. option_string (str): Optional string for specific argument name. Default: None.
  41. """
  42. outfile = values
  43. if len(outfile) > ARGUMENT_LENGTH_LIMIT:
  44. parser_in.error(
  45. f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  46. if outfile.startswith('~'):
  47. outfile = os.path.realpath(os.path.expanduser(outfile))
  48. if not outfile.startswith('/'):
  49. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  50. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  51. parser_in.error(f'{option_string} {outfile} not accessible')
  52. return outfile
  53. def __call__(self, parser_in, namespace, values, option_string=None):
  54. """
  55. Inherited __call__ method from argparse.Action.
  56. Args:
  57. parser_in (ArgumentParser): Passed-in argument parser.
  58. namespace (Namespace): Namespace object to hold arguments.
  59. values (object): Argument values with type depending on argument definition.
  60. option_string (str): Optional string for specific argument name. Default: None.
  61. """
  62. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  63. outfile_dir = self.check_path(parser_in, values, option_string)
  64. if os.path.isfile(outfile_dir):
  65. parser_in.error(f'{option_string} {outfile_dir} is a file')
  66. setattr(namespace, self.dest, outfile_dir)
  67. class OutputDirAction(argparse.Action):
  68. """File directory action class definition."""
  69. def __call__(self, parser_in, namespace, values, option_string=None):
  70. """
  71. Inherited __call__ method from argparse.Action.
  72. Args:
  73. parser_in (ArgumentParser): Passed-in argument parser.
  74. namespace (Namespace): Namespace object to hold arguments.
  75. values (object): Argument values with type depending on argument definition.
  76. option_string (str): Optional string for specific argument name. Default: None.
  77. """
  78. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  79. output = values
  80. if len(output) > ARGUMENT_LENGTH_LIMIT:
  81. parser_in.error(
  82. f"The length of {option_string}{output} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  83. if output.startswith('~'):
  84. output = os.path.realpath(os.path.expanduser(output))
  85. if not output.startswith('/'):
  86. output = os.path.realpath(os.path.join(os.getcwd(), output))
  87. if os.path.exists(output):
  88. if not os.access(output, os.R_OK):
  89. parser_in.error(f'{option_string} {output} not accessible')
  90. if os.path.isfile(output):
  91. parser_in.error(f'{option_string} {output} is a file')
  92. setattr(namespace, self.dest, output)
  93. class ProjectPathAction(argparse.Action):
  94. """Project directory action class definition."""
  95. def __call__(self, parser_in, namespace, values, option_string=None):
  96. """
  97. Inherited __call__ method from argparse.Action.
  98. Args:
  99. parser_in (ArgumentParser): Passed-in argument parser.
  100. namespace (Namespace): Namespace object to hold arguments.
  101. values (object): Argument values with type depending on argument definition.
  102. option_string (str): Optional string for specific argument name. Default: None.
  103. """
  104. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  105. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  106. if not os.path.exists(outfile_dir):
  107. parser_in.error(f'{option_string} {outfile_dir} not exists')
  108. if not os.path.isdir(outfile_dir):
  109. parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')
  110. setattr(namespace, self.dest, outfile_dir)
  111. class InFileAction(argparse.Action):
  112. """Input File action class definition."""
  113. def __call__(self, parser_in, namespace, values, option_string=None):
  114. """
  115. Inherited __call__ method from argparse.Action.
  116. Args:
  117. parser_in (ArgumentParser): Passed-in argument parser.
  118. namespace (Namespace): Namespace object to hold arguments.
  119. values (object): Argument values with type depending on argument definition.
  120. option_string (str): Optional string for specific argument name. Default: None.
  121. """
  122. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  123. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  124. if not os.path.exists(outfile_dir):
  125. parser_in.error(f'{option_string} {outfile_dir} not exists')
  126. if not os.path.isfile(outfile_dir):
  127. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  128. if not os.path.basename(outfile_dir).endswith("py"):
  129. parser_in.error(f'{option_string} {outfile_dir} is not a valid python file')
  130. setattr(namespace, self.dest, outfile_dir)
  131. class ModelFileAction(argparse.Action):
  132. """Model File action class definition."""
  133. def __call__(self, parser_in, namespace, values, option_string=None):
  134. """
  135. Inherited __call__ method from argparse.Action.
  136. Args:
  137. parser_in (ArgumentParser): Passed-in argument parser.
  138. namespace (Namespace): Namespace object to hold arguments.
  139. values (object): Argument values with type depending on argument definition.
  140. option_string (str): Optional string for specific argument name. Default: None.
  141. """
  142. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  143. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  144. if not os.path.exists(outfile_dir):
  145. parser_in.error(f'{option_string} {outfile_dir} not exists')
  146. if not os.path.isfile(outfile_dir):
  147. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  148. setattr(namespace, self.dest, outfile_dir)
  149. class LogFileAction(argparse.Action):
  150. """Log file action class definition."""
  151. def __call__(self, parser_in, namespace, values, option_string=None):
  152. """
  153. Inherited __call__ method from FileDirAction.
  154. Args:
  155. parser_in (ArgumentParser): Passed-in argument parser.
  156. namespace (Namespace): Namespace object to hold arguments.
  157. values (object): Argument values with type depending on argument definition.
  158. option_string (str): Optional string for specific argument name. Default: None.
  159. """
  160. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  161. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  162. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  163. parser_in.error(f'{option_string} {outfile_dir} is not a directory')
  164. setattr(namespace, self.dest, outfile_dir)
  165. class ShapeAction(argparse.Action):
  166. """Shape action class definition."""
  167. def __call__(self, parser_in, namespace, values, option_string=None):
  168. """
  169. Inherited __call__ method from FileDirAction.
  170. Args:
  171. parser_in (ArgumentParser): Passed-in argument parser.
  172. namespace (Namespace): Namespace object to hold arguments.
  173. values (object): Argument values with type depending on argument definition.
  174. option_string (str): Optional string for specific argument name. Default: None.
  175. """
  176. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  177. in_shape = None
  178. shape_str = values
  179. shape_list = shape_str.split(':')
  180. if not len(shape_list) == EXPECTED_NUMBER:
  181. parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")
  182. try:
  183. in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')]
  184. except ValueError:
  185. parser_in.error(
  186. f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
  187. setattr(namespace, self.dest, in_shape)
  188. class NodeAction(argparse.Action):
  189. """Node action class definition."""
  190. def __call__(self, parser_in, namespace, values, option_string=None):
  191. """
  192. Inherited __call__ method from FileDirAction.
  193. Args:
  194. parser_in (ArgumentParser): Passed-in argument parser.
  195. namespace (Namespace): Namespace object to hold arguments.
  196. values (object): Argument values with type depending on argument definition.
  197. option_string (str): Optional string for specific argument name. Default: None.
  198. """
  199. ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
  200. node_str = values
  201. if len(node_str) > ARGUMENT_LENGTH_LIMIT:
  202. parser_in.error(
  203. f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
  204. )
  205. node_list = node_str.split(',')
  206. if not len(node_list) == EXPECTED_NUMBER:
  207. parser_in.error(f"Only support one {option_string} now, but get {len(node_list)}.")
  208. setattr(namespace, self.dest, node_str)
  209. parser = argparse.ArgumentParser(
  210. prog='mindconverter',
  211. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
  212. allow_abbrev=False)
  213. parser.add_argument(
  214. '--version',
  215. action='version',
  216. version='%(prog)s ({})'.format(mindinsight.__version__))
  217. parser.add_argument(
  218. '--in_file',
  219. type=str,
  220. action=InFileAction,
  221. required=False,
  222. default=None,
  223. help="""
  224. Specify path for script file to use AST schema to
  225. do script conversation.
  226. """)
  227. parser.add_argument(
  228. '--model_file',
  229. type=str,
  230. action=ModelFileAction,
  231. required=False,
  232. help="""
  233. PyTorch .pth or Tensorflow .pb model file path to use graph
  234. based schema to do script generation. When
  235. `--in_file` and `--model_file` are both provided,
  236. use AST schema as default.
  237. """)
  238. parser.add_argument(
  239. '--shape',
  240. type=str,
  241. action=ShapeAction,
  242. default=None,
  243. required=False,
  244. help="""
  245. Optional, expected input tensor shape of
  246. `--model_file`. It's required when use graph based
  247. schema.
  248. Usage: --shape 1,3,244,244
  249. """)
  250. parser.add_argument(
  251. '--input_nodes',
  252. type=str,
  253. action=NodeAction,
  254. default=None,
  255. required=False,
  256. help="""
  257. Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
  258. Usage: --input_nodes input_1:0,input_2:0
  259. """)
  260. parser.add_argument(
  261. '--output_nodes',
  262. type=str,
  263. action=NodeAction,
  264. default=None,
  265. required=False,
  266. help="""
  267. Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
  268. Usage: --output_nodes output_1:0,output_2:0
  269. """)
  270. parser.add_argument(
  271. '--output',
  272. type=str,
  273. action=OutputDirAction,
  274. default=os.path.join(os.getcwd(), 'output'),
  275. help="""
  276. Optional, specify path for converted script file
  277. directory. Default output directory is `output` folder
  278. in the current working directory.
  279. """)
  280. parser.add_argument(
  281. '--report',
  282. type=str,
  283. action=LogFileAction,
  284. default=None,
  285. help="""
  286. Optional, specify report directory. Default is
  287. converted script directory.
  288. """)
  289. parser.add_argument(
  290. '--project_path',
  291. type=str,
  292. action=ProjectPathAction,
  293. required=False,
  294. default=None,
  295. help="""
  296. Optional, PyTorch scripts project path. If PyTorch
  297. project is not in PYTHONPATH, please assign
  298. `--project_path` when use graph based schema.
  299. Usage: --project_path ~/script_file/
  300. """)
  301. def cli_entry():
  302. """Entry point for mindconverter CLI."""
  303. permissions = os.R_OK | os.W_OK | os.X_OK
  304. os.umask(permissions << 3 | permissions)
  305. argv = sys.argv[1:]
  306. if not argv:
  307. argv = ['-h']
  308. args = parser.parse_args(argv)
  309. else:
  310. args = parser.parse_args()
  311. mode = permissions << 6
  312. os.makedirs(args.output, mode=mode, exist_ok=True)
  313. if args.report is None:
  314. args.report = args.output
  315. os.makedirs(args.report, mode=mode, exist_ok=True)
  316. _run(args.in_file, args.model_file,
  317. args.shape,
  318. args.input_nodes, args.output_nodes,
  319. args.output, args.report,
  320. args.project_path)
  321. def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
  322. """
  323. Run converter command.
  324. Args:
  325. in_files (str): The file path or directory to convert.
  326. model_file(str): The pytorch .pth to convert on graph based schema.
  327. shape(list): The input tensor shape of module_file.
  328. input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
  329. output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
  330. out_dir (str): The output directory to save converted file.
  331. report (str): The report file path.
  332. project_path(str): Pytorch scripts project path.
  333. """
  334. if in_files:
  335. files_config = {
  336. 'root_path': in_files,
  337. 'in_files': [],
  338. 'outfile_dir': out_dir,
  339. 'report_dir': report if report else out_dir
  340. }
  341. if os.path.isfile(in_files):
  342. files_config['root_path'] = os.path.dirname(in_files)
  343. files_config['in_files'] = [in_files]
  344. else:
  345. for root_dir, _, files in os.walk(in_files):
  346. for file in files:
  347. files_config['in_files'].append(os.path.join(root_dir, file))
  348. main(files_config)
  349. elif model_file:
  350. file_config = {
  351. 'model_file': model_file,
  352. 'shape': shape if shape else [],
  353. 'input_nodes': input_nodes,
  354. 'output_nodes': output_nodes,
  355. 'outfile_dir': out_dir,
  356. 'report_dir': report if report else out_dir
  357. }
  358. if project_path:
  359. paths = sys.path
  360. if project_path not in paths:
  361. sys.path.append(project_path)
  362. main_graph_base_converter(file_config)
  363. else:
  364. error_msg = "`--in_file` and `--model_file` should be set at least one."
  365. error = FileNotFoundError(error_msg)
  366. log.error(str(error))
  367. log.exception(error)
  368. raise error