You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

cli.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER
  22. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  23. from mindinsight.mindconverter.common.log import logger as log
  24. class FileDirAction(argparse.Action):
  25. """File directory action class definition."""
  26. @staticmethod
  27. def check_path(parser_in, values, option_string=None):
  28. """
  29. Check argument for file path.
  30. Args:
  31. parser_in (ArgumentParser): Passed-in argument parser.
  32. values (object): Argument values with type depending on argument definition.
  33. option_string (str): Optional string for specific argument name. Default: None.
  34. """
  35. outfile = values
  36. if len(outfile) > ARGUMENT_LENGTH_LIMIT:
  37. parser_in.error(
  38. f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
  39. if outfile.startswith('~'):
  40. outfile = os.path.realpath(os.path.expanduser(outfile))
  41. if not outfile.startswith('/'):
  42. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  43. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  44. parser_in.error(f'{option_string} {outfile} not accessible')
  45. return outfile
  46. def __call__(self, parser_in, namespace, values, option_string=None):
  47. """
  48. Inherited __call__ method from argparse.Action.
  49. Args:
  50. parser_in (ArgumentParser): Passed-in argument parser.
  51. namespace (Namespace): Namespace object to hold arguments.
  52. values (object): Argument values with type depending on argument definition.
  53. option_string (str): Optional string for specific argument name. Default: None.
  54. """
  55. outfile_dir = self.check_path(parser_in, values, option_string)
  56. if os.path.isfile(outfile_dir):
  57. parser_in.error(f'{option_string} {outfile_dir} is a file')
  58. setattr(namespace, self.dest, outfile_dir)
  59. class OutputDirAction(argparse.Action):
  60. """File directory action class definition."""
  61. def __call__(self, parser_in, namespace, values, option_string=None):
  62. """
  63. Inherited __call__ method from argparse.Action.
  64. Args:
  65. parser_in (ArgumentParser): Passed-in argument parser.
  66. namespace (Namespace): Namespace object to hold arguments.
  67. values (object): Argument values with type depending on argument definition.
  68. option_string (str): Optional string for specific argument name. Default: None.
  69. """
  70. output = values
  71. if output.startswith('~'):
  72. output = os.path.realpath(os.path.expanduser(output))
  73. if not output.startswith('/'):
  74. output = os.path.realpath(os.path.join(os.getcwd(), output))
  75. if os.path.exists(output):
  76. if not os.access(output, os.R_OK):
  77. parser_in.error(f'{option_string} {output} not accessible')
  78. if os.path.isfile(output):
  79. parser_in.error(f'{option_string} {output} is a file')
  80. setattr(namespace, self.dest, output)
  81. class ProjectPathAction(argparse.Action):
  82. """Project directory action class definition."""
  83. def __call__(self, parser_in, namespace, values, option_string=None):
  84. """
  85. Inherited __call__ method from argparse.Action.
  86. Args:
  87. parser_in (ArgumentParser): Passed-in argument parser.
  88. namespace (Namespace): Namespace object to hold arguments.
  89. values (object): Argument values with type depending on argument definition.
  90. option_string (str): Optional string for specific argument name. Default: None.
  91. """
  92. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  93. if not os.path.exists(outfile_dir):
  94. parser_in.error(f'{option_string} {outfile_dir} not exists')
  95. if not os.path.isdir(outfile_dir):
  96. parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')
  97. setattr(namespace, self.dest, outfile_dir)
  98. class InFileAction(argparse.Action):
  99. """Input File action class definition."""
  100. def __call__(self, parser_in, namespace, values, option_string=None):
  101. """
  102. Inherited __call__ method from argparse.Action.
  103. Args:
  104. parser_in (ArgumentParser): Passed-in argument parser.
  105. namespace (Namespace): Namespace object to hold arguments.
  106. values (object): Argument values with type depending on argument definition.
  107. option_string (str): Optional string for specific argument name. Default: None.
  108. """
  109. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  110. if not os.path.exists(outfile_dir):
  111. parser_in.error(f'{option_string} {outfile_dir} not exists')
  112. if not os.path.isfile(outfile_dir):
  113. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  114. setattr(namespace, self.dest, outfile_dir)
  115. class ModelFileAction(argparse.Action):
  116. """Model File action class definition."""
  117. def __call__(self, parser_in, namespace, values, option_string=None):
  118. """
  119. Inherited __call__ method from argparse.Action.
  120. Args:
  121. parser_in (ArgumentParser): Passed-in argument parser.
  122. namespace (Namespace): Namespace object to hold arguments.
  123. values (object): Argument values with type depending on argument definition.
  124. option_string (str): Optional string for specific argument name. Default: None.
  125. """
  126. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  127. if not os.path.exists(outfile_dir):
  128. parser_in.error(f'{option_string} {outfile_dir} not exists')
  129. if not os.path.isfile(outfile_dir):
  130. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  131. setattr(namespace, self.dest, outfile_dir)
  132. class LogFileAction(argparse.Action):
  133. """Log file action class definition."""
  134. def __call__(self, parser_in, namespace, values, option_string=None):
  135. """
  136. Inherited __call__ method from FileDirAction.
  137. Args:
  138. parser_in (ArgumentParser): Passed-in argument parser.
  139. namespace (Namespace): Namespace object to hold arguments.
  140. values (object): Argument values with type depending on argument definition.
  141. option_string (str): Optional string for specific argument name. Default: None.
  142. """
  143. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  144. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  145. parser_in.error(f'{option_string} {outfile_dir} is not a directory')
  146. setattr(namespace, self.dest, outfile_dir)
  147. class ShapeAction(argparse.Action):
  148. """Shape action class definition."""
  149. def __call__(self, parser_in, namespace, values, option_string=None):
  150. """
  151. Inherited __call__ method from FileDirAction.
  152. Args:
  153. parser_in (ArgumentParser): Passed-in argument parser.
  154. namespace (Namespace): Namespace object to hold arguments.
  155. values (object): Argument values with type depending on argument definition.
  156. option_string (str): Optional string for specific argument name. Default: None.
  157. """
  158. in_shape = None
  159. shape_str = values
  160. shape_list = shape_str.split(';')
  161. if not len(shape_list) == EXPECTED_SHAPE_NUMBER:
  162. parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")
  163. try:
  164. in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')]
  165. except ValueError:
  166. parser_in.error(
  167. f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
  168. setattr(namespace, self.dest, in_shape)
  169. class NodeAction(argparse.Action):
  170. """Node action class definition."""
  171. def __call__(self, parser_in, namespace, values, option_string=None):
  172. """
  173. Inherited __call__ method from FileDirAction.
  174. Args:
  175. parser_in (ArgumentParser): Passed-in argument parser.
  176. namespace (Namespace): Namespace object to hold arguments.
  177. values (object): Argument values with type depending on argument definition.
  178. option_string (str): Optional string for specific argument name. Default: None.
  179. """
  180. node_str = values
  181. if len(node_str) > ARGUMENT_LENGTH_LIMIT:
  182. parser_in.error(
  183. f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
  184. )
  185. setattr(namespace, self.dest, node_str)
  186. parser = argparse.ArgumentParser(
  187. prog='mindconverter',
  188. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
  189. parser.add_argument(
  190. '--version',
  191. action='version',
  192. version='%(prog)s ({})'.format(mindinsight.__version__))
  193. parser.add_argument(
  194. '--in_file',
  195. type=str,
  196. action=InFileAction,
  197. required=False,
  198. default=None,
  199. help="""
  200. Specify path for script file to use AST schema to
  201. do script conversation.
  202. """)
  203. parser.add_argument(
  204. '--model_file',
  205. type=str,
  206. action=ModelFileAction,
  207. required=False,
  208. help="""
  209. PyTorch .pth or Tensorflow .pb model file path to use graph
  210. based schema to do script generation. When
  211. `--in_file` and `--model_file` are both provided,
  212. use AST schema as default.
  213. """)
  214. parser.add_argument(
  215. '--shape',
  216. type=str,
  217. action=ShapeAction,
  218. default=None,
  219. required=False,
  220. help="""
  221. Optional, expected input tensor shape of
  222. `--model_file`. It's required when use graph based
  223. schema.
  224. Usage: --shape 1,3,244,244
  225. """)
  226. parser.add_argument(
  227. '--input_nodes',
  228. type=str,
  229. action=NodeAction,
  230. default=None,
  231. required=False,
  232. help="""
  233. Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
  234. Usage: --input_nodes input_1:0,input_2:0
  235. """)
  236. parser.add_argument(
  237. '--output_nodes',
  238. type=str,
  239. action=NodeAction,
  240. default=None,
  241. required=False,
  242. help="""
  243. Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
  244. Usage: --output_nodes output_1:0,output_2:0
  245. """)
  246. parser.add_argument(
  247. '--output',
  248. type=str,
  249. action=OutputDirAction,
  250. default=os.path.join(os.getcwd(), 'output'),
  251. help="""
  252. Optional, specify path for converted script file
  253. directory. Default output directory is `output` folder
  254. in the current working directory.
  255. """)
  256. parser.add_argument(
  257. '--report',
  258. type=str,
  259. action=LogFileAction,
  260. default=None,
  261. help="""
  262. Optional, specify report directory. Default is
  263. converted script directory.
  264. """)
  265. parser.add_argument(
  266. '--project_path',
  267. type=str,
  268. action=ProjectPathAction,
  269. required=False,
  270. default=None,
  271. help="""
  272. Optional, PyTorch scripts project path. If PyTorch
  273. project is not in PYTHONPATH, please assign
  274. `--project_path` when use graph based schema.
  275. Usage: --project_path ~/script_file/
  276. """)
  277. def cli_entry():
  278. """Entry point for mindconverter CLI."""
  279. permissions = os.R_OK | os.W_OK | os.X_OK
  280. os.umask(permissions << 3 | permissions)
  281. argv = sys.argv[1:]
  282. if not argv:
  283. argv = ['-h']
  284. args = parser.parse_args(argv)
  285. else:
  286. args = parser.parse_args()
  287. mode = permissions << 6
  288. os.makedirs(args.output, mode=mode, exist_ok=True)
  289. if args.report is None:
  290. args.report = args.output
  291. os.makedirs(args.report, mode=mode, exist_ok=True)
  292. _run(args.in_file, args.model_file,
  293. args.shape,
  294. args.input_nodes, args.output_nodes,
  295. args.output, args.report,
  296. args.project_path)
  297. def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
  298. """
  299. Run converter command.
  300. Args:
  301. in_files (str): The file path or directory to convert.
  302. model_file(str): The pytorch .pth to convert on graph based schema.
  303. shape(list): The input tensor shape of module_file.
  304. input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
  305. output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
  306. out_dir (str): The output directory to save converted file.
  307. report (str): The report file path.
  308. project_path(str): Pytorch scripts project path.
  309. """
  310. if in_files:
  311. files_config = {
  312. 'root_path': in_files,
  313. 'in_files': [],
  314. 'outfile_dir': out_dir,
  315. 'report_dir': report if report else out_dir
  316. }
  317. if os.path.isfile(in_files):
  318. files_config['root_path'] = os.path.dirname(in_files)
  319. files_config['in_files'] = [in_files]
  320. else:
  321. for root_dir, _, files in os.walk(in_files):
  322. for file in files:
  323. files_config['in_files'].append(os.path.join(root_dir, file))
  324. main(files_config)
  325. elif model_file:
  326. file_config = {
  327. 'model_file': model_file,
  328. 'shape': shape if shape else [],
  329. 'input_nodes': input_nodes,
  330. 'output_nodes': output_nodes,
  331. 'outfile_dir': out_dir,
  332. 'report_dir': report if report else out_dir
  333. }
  334. if project_path:
  335. paths = sys.path
  336. if project_path not in paths:
  337. sys.path.append(project_path)
  338. main_graph_base_converter(file_config)
  339. else:
  340. error_msg = "`--in_file` and `--model_file` should be set at least one."
  341. error = FileNotFoundError(error_msg)
  342. log.error(str(error))
  343. log.exception(error)
  344. raise error