You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cli.py 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER
  22. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  23. from mindinsight.mindconverter.common.log import logger as log
  24. class FileDirAction(argparse.Action):
  25. """File directory action class definition."""
  26. @staticmethod
  27. def check_path(parser_in, values, option_string=None):
  28. """
  29. Check argument for file path.
  30. Args:
  31. parser_in (ArgumentParser): Passed-in argument parser.
  32. values (object): Argument values with type depending on argument definition.
  33. option_string (str): Optional string for specific argument name. Default: None.
  34. """
  35. outfile = values
  36. if len(outfile) > ARGUMENT_LENGTH_LIMIT:
  37. parser_in.error(
  38. f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  39. if outfile.startswith('~'):
  40. outfile = os.path.realpath(os.path.expanduser(outfile))
  41. if not outfile.startswith('/'):
  42. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  43. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  44. parser_in.error(f'{option_string} {outfile} not accessible')
  45. return outfile
  46. def __call__(self, parser_in, namespace, values, option_string=None):
  47. """
  48. Inherited __call__ method from argparse.Action.
  49. Args:
  50. parser_in (ArgumentParser): Passed-in argument parser.
  51. namespace (Namespace): Namespace object to hold arguments.
  52. values (object): Argument values with type depending on argument definition.
  53. option_string (str): Optional string for specific argument name. Default: None.
  54. """
  55. outfile_dir = self.check_path(parser_in, values, option_string)
  56. if os.path.isfile(outfile_dir):
  57. parser_in.error(f'{option_string} {outfile_dir} is a file')
  58. setattr(namespace, self.dest, outfile_dir)
  59. class OutputDirAction(argparse.Action):
  60. """File directory action class definition."""
  61. def __call__(self, parser_in, namespace, values, option_string=None):
  62. """
  63. Inherited __call__ method from argparse.Action.
  64. Args:
  65. parser_in (ArgumentParser): Passed-in argument parser.
  66. namespace (Namespace): Namespace object to hold arguments.
  67. values (object): Argument values with type depending on argument definition.
  68. option_string (str): Optional string for specific argument name. Default: None.
  69. """
  70. output = values
  71. if len(output) > ARGUMENT_LENGTH_LIMIT:
  72. parser_in.error(
  73. f"The length of {option_string}{output} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  74. if output.startswith('~'):
  75. output = os.path.realpath(os.path.expanduser(output))
  76. if not output.startswith('/'):
  77. output = os.path.realpath(os.path.join(os.getcwd(), output))
  78. if os.path.exists(output):
  79. if not os.access(output, os.R_OK):
  80. parser_in.error(f'{option_string} {output} not accessible')
  81. if os.path.isfile(output):
  82. parser_in.error(f'{option_string} {output} is a file')
  83. setattr(namespace, self.dest, output)
  84. class ProjectPathAction(argparse.Action):
  85. """Project directory action class definition."""
  86. def __call__(self, parser_in, namespace, values, option_string=None):
  87. """
  88. Inherited __call__ method from argparse.Action.
  89. Args:
  90. parser_in (ArgumentParser): Passed-in argument parser.
  91. namespace (Namespace): Namespace object to hold arguments.
  92. values (object): Argument values with type depending on argument definition.
  93. option_string (str): Optional string for specific argument name. Default: None.
  94. """
  95. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  96. if not os.path.exists(outfile_dir):
  97. parser_in.error(f'{option_string} {outfile_dir} not exists')
  98. if not os.path.isdir(outfile_dir):
  99. parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')
  100. setattr(namespace, self.dest, outfile_dir)
  101. class InFileAction(argparse.Action):
  102. """Input File action class definition."""
  103. def __call__(self, parser_in, namespace, values, option_string=None):
  104. """
  105. Inherited __call__ method from argparse.Action.
  106. Args:
  107. parser_in (ArgumentParser): Passed-in argument parser.
  108. namespace (Namespace): Namespace object to hold arguments.
  109. values (object): Argument values with type depending on argument definition.
  110. option_string (str): Optional string for specific argument name. Default: None.
  111. """
  112. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  113. if not os.path.exists(outfile_dir):
  114. parser_in.error(f'{option_string} {outfile_dir} not exists')
  115. if not os.path.isfile(outfile_dir):
  116. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  117. setattr(namespace, self.dest, outfile_dir)
  118. class ModelFileAction(argparse.Action):
  119. """Model File action class definition."""
  120. def __call__(self, parser_in, namespace, values, option_string=None):
  121. """
  122. Inherited __call__ method from argparse.Action.
  123. Args:
  124. parser_in (ArgumentParser): Passed-in argument parser.
  125. namespace (Namespace): Namespace object to hold arguments.
  126. values (object): Argument values with type depending on argument definition.
  127. option_string (str): Optional string for specific argument name. Default: None.
  128. """
  129. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  130. if not os.path.exists(outfile_dir):
  131. parser_in.error(f'{option_string} {outfile_dir} not exists')
  132. if not os.path.isfile(outfile_dir):
  133. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  134. setattr(namespace, self.dest, outfile_dir)
  135. class LogFileAction(argparse.Action):
  136. """Log file action class definition."""
  137. def __call__(self, parser_in, namespace, values, option_string=None):
  138. """
  139. Inherited __call__ method from FileDirAction.
  140. Args:
  141. parser_in (ArgumentParser): Passed-in argument parser.
  142. namespace (Namespace): Namespace object to hold arguments.
  143. values (object): Argument values with type depending on argument definition.
  144. option_string (str): Optional string for specific argument name. Default: None.
  145. """
  146. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  147. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  148. parser_in.error(f'{option_string} {outfile_dir} is not a directory')
  149. setattr(namespace, self.dest, outfile_dir)
  150. class ShapeAction(argparse.Action):
  151. """Shape action class definition."""
  152. def __call__(self, parser_in, namespace, values, option_string=None):
  153. """
  154. Inherited __call__ method from FileDirAction.
  155. Args:
  156. parser_in (ArgumentParser): Passed-in argument parser.
  157. namespace (Namespace): Namespace object to hold arguments.
  158. values (object): Argument values with type depending on argument definition.
  159. option_string (str): Optional string for specific argument name. Default: None.
  160. """
  161. in_shape = None
  162. shape_str = values
  163. shape_list = shape_str.split(':')
  164. if not len(shape_list) == EXPECTED_SHAPE_NUMBER:
  165. parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")
  166. try:
  167. in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')]
  168. except ValueError:
  169. parser_in.error(
  170. f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
  171. setattr(namespace, self.dest, in_shape)
  172. class NodeAction(argparse.Action):
  173. """Node action class definition."""
  174. def __call__(self, parser_in, namespace, values, option_string=None):
  175. """
  176. Inherited __call__ method from FileDirAction.
  177. Args:
  178. parser_in (ArgumentParser): Passed-in argument parser.
  179. namespace (Namespace): Namespace object to hold arguments.
  180. values (object): Argument values with type depending on argument definition.
  181. option_string (str): Optional string for specific argument name. Default: None.
  182. """
  183. node_str = values
  184. if len(node_str) > ARGUMENT_LENGTH_LIMIT:
  185. parser_in.error(
  186. f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
  187. )
  188. setattr(namespace, self.dest, node_str)
  189. parser = argparse.ArgumentParser(
  190. prog='mindconverter',
  191. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
  192. parser.add_argument(
  193. '--version',
  194. action='version',
  195. version='%(prog)s ({})'.format(mindinsight.__version__))
  196. parser.add_argument(
  197. '--in_file',
  198. type=str,
  199. action=InFileAction,
  200. required=False,
  201. default=None,
  202. help="""
  203. Specify path for script file to use AST schema to
  204. do script conversation.
  205. """)
  206. parser.add_argument(
  207. '--model_file',
  208. type=str,
  209. action=ModelFileAction,
  210. required=False,
  211. help="""
  212. PyTorch .pth or Tensorflow .pb model file path to use graph
  213. based schema to do script generation. When
  214. `--in_file` and `--model_file` are both provided,
  215. use AST schema as default.
  216. """)
  217. parser.add_argument(
  218. '--shape',
  219. type=str,
  220. action=ShapeAction,
  221. default=None,
  222. required=False,
  223. help="""
  224. Optional, expected input tensor shape of
  225. `--model_file`. It's required when use graph based
  226. schema.
  227. Usage: --shape 1,3,244,244
  228. """)
  229. parser.add_argument(
  230. '--input_nodes',
  231. type=str,
  232. action=NodeAction,
  233. default=None,
  234. required=False,
  235. help="""
  236. Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
  237. Usage: --input_nodes input_1:0,input_2:0
  238. """)
  239. parser.add_argument(
  240. '--output_nodes',
  241. type=str,
  242. action=NodeAction,
  243. default=None,
  244. required=False,
  245. help="""
  246. Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
  247. Usage: --output_nodes output_1:0,output_2:0
  248. """)
  249. parser.add_argument(
  250. '--output',
  251. type=str,
  252. action=OutputDirAction,
  253. default=os.path.join(os.getcwd(), 'output'),
  254. help="""
  255. Optional, specify path for converted script file
  256. directory. Default output directory is `output` folder
  257. in the current working directory.
  258. """)
  259. parser.add_argument(
  260. '--report',
  261. type=str,
  262. action=LogFileAction,
  263. default=None,
  264. help="""
  265. Optional, specify report directory. Default is
  266. converted script directory.
  267. """)
  268. parser.add_argument(
  269. '--project_path',
  270. type=str,
  271. action=ProjectPathAction,
  272. required=False,
  273. default=None,
  274. help="""
  275. Optional, PyTorch scripts project path. If PyTorch
  276. project is not in PYTHONPATH, please assign
  277. `--project_path` when use graph based schema.
  278. Usage: --project_path ~/script_file/
  279. """)
  280. def cli_entry():
  281. """Entry point for mindconverter CLI."""
  282. permissions = os.R_OK | os.W_OK | os.X_OK
  283. os.umask(permissions << 3 | permissions)
  284. argv = sys.argv[1:]
  285. if not argv:
  286. argv = ['-h']
  287. args = parser.parse_args(argv)
  288. else:
  289. args = parser.parse_args()
  290. mode = permissions << 6
  291. os.makedirs(args.output, mode=mode, exist_ok=True)
  292. if args.report is None:
  293. args.report = args.output
  294. os.makedirs(args.report, mode=mode, exist_ok=True)
  295. _run(args.in_file, args.model_file,
  296. args.shape,
  297. args.input_nodes, args.output_nodes,
  298. args.output, args.report,
  299. args.project_path)
  300. def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
  301. """
  302. Run converter command.
  303. Args:
  304. in_files (str): The file path or directory to convert.
  305. model_file(str): The pytorch .pth to convert on graph based schema.
  306. shape(list): The input tensor shape of module_file.
  307. input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
  308. output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
  309. out_dir (str): The output directory to save converted file.
  310. report (str): The report file path.
  311. project_path(str): Pytorch scripts project path.
  312. """
  313. if in_files:
  314. files_config = {
  315. 'root_path': in_files,
  316. 'in_files': [],
  317. 'outfile_dir': out_dir,
  318. 'report_dir': report if report else out_dir
  319. }
  320. if os.path.isfile(in_files):
  321. files_config['root_path'] = os.path.dirname(in_files)
  322. files_config['in_files'] = [in_files]
  323. else:
  324. for root_dir, _, files in os.walk(in_files):
  325. for file in files:
  326. files_config['in_files'].append(os.path.join(root_dir, file))
  327. main(files_config)
  328. elif model_file:
  329. file_config = {
  330. 'model_file': model_file,
  331. 'shape': shape if shape else [],
  332. 'input_nodes': input_nodes,
  333. 'output_nodes': output_nodes,
  334. 'outfile_dir': out_dir,
  335. 'report_dir': report if report else out_dir
  336. }
  337. if project_path:
  338. paths = sys.path
  339. if project_path not in paths:
  340. sys.path.append(project_path)
  341. main_graph_base_converter(file_config)
  342. else:
  343. error_msg = "`--in_file` and `--model_file` should be set at least one."
  344. error = FileNotFoundError(error_msg)
  345. log.error(str(error))
  346. log.exception(error)
  347. raise error