You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cli.py 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type
  22. from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, \
  23. ARGUMENT_NUM_LIMIT, ARGUMENT_LEN_LIMIT, FrameworkType
  24. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  25. from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
  class ArgsCheck:
      """Validation helpers shared by this module's argparse actions."""

      @staticmethod
      def check_repeated(namespace, dest, default, option_string, parser_in):
          """
          Reject an option that appears more than once on the command line.

          A stored value that is no longer the very same object as ``default`` is
          taken as evidence that the option was already parsed once. This relies on
          the namespace starting out holding the default.

          Args:
              namespace (Namespace): Namespace object holding parsed arguments.
              dest (str): Attribute name the option is stored under.
              default (object): The action's default value.
              option_string (str): Option name as typed, used in the error message.
              parser_in (ArgumentParser): Parser whose ``error`` reports the problem.
          """
          current_value = getattr(namespace, dest, default)
          if current_value is default:
              return
          parser_in.error(f'Parameter `{option_string}` is set repeatedly.')
  class FileDirAction(argparse.Action):
      """Action for an option whose value must be a directory path, not a file."""

      @staticmethod
      def check_path(parser_in, values, option_string=None):
          """
          Validate a path argument and return it in absolute form.

          A leading ``~`` is expanded and a relative path is resolved against the
          current working directory; both results go through ``os.path.realpath``.
          A path that is already absolute as given is returned unchanged.
          A path that does not exist is accepted here; an existing path must be readable.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              values (str): Path given on the command line.
              option_string (str): Optional string for specific argument name. Default: None.

          Returns:
              str, the absolute path.
          """
          outfile = values
          if len(outfile) > ARGUMENT_LENGTH_LIMIT:
              # Keep a space between the option name and its value, as the other messages do.
              parser_in.error(
                  f"The length of {option_string} {outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")
          if outfile.startswith('~'):
              outfile = os.path.realpath(os.path.expanduser(outfile))
          if not outfile.startswith('/'):
              outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
          if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
              parser_in.error(f'{option_string} {outfile} not accessible')
          return outfile

      def __call__(self, parser_in, namespace, values, option_string=None):
          """
          Validate the path, reject an existing file, and store the absolute path.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              namespace (Namespace): Namespace object to hold arguments.
              values (str): Path given on the command line.
              option_string (str): Optional string for specific argument name. Default: None.
          """
          ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
          outfile_dir = self.check_path(parser_in, values, option_string)
          if os.path.isfile(outfile_dir):
              parser_in.error(f'{option_string} {outfile_dir} is a file')
          setattr(namespace, self.dest, outfile_dir)
  class OutputDirAction(argparse.Action):
      """Action for the output directory: a path that must not name an existing file."""

      def __call__(self, parser_in, namespace, values, option_string=None):
          """
          Validate the output directory and store its absolute path.

          The common path checks are delegated to ``FileDirAction.check_path``:
          length limit, ``~`` expansion, resolving relative paths, and readability
          of an existing path. This block previously duplicated them inline.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              namespace (Namespace): Namespace object to hold arguments.
              values (str): Output directory given on the command line.
              option_string (str): Optional string for specific argument name. Default: None.
          """
          ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
          output = FileDirAction.check_path(parser_in, values, option_string)
          # The directory itself may not exist yet, but the path must not be a regular file.
          if os.path.isfile(output):
              parser_in.error(f'{option_string} {output} is a file')
          setattr(namespace, self.dest, output)
  class InFileAction(argparse.Action):
      """Action for the input script option: an existing Python source file."""

      @staticmethod
      def _is_python_file(path):
          """
          Tell whether ``path`` names a Python source file by its ``.py`` extension.

          A bare ``py`` suffix (e.g. ``happy``) is not enough; the dot is required.

          Args:
              path (str): File path to inspect.

          Returns:
              bool, True if the file name ends with ``.py``.
          """
          return os.path.basename(path).endswith('.py')

      def __call__(self, parser_in, namespace, values, option_string=None):
          """
          Validate the input script path and store its absolute form.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              namespace (Namespace): Namespace object to hold arguments.
              values (str): Script path given on the command line.
              option_string (str): Optional string for specific argument name. Default: None.
          """
          ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
          outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
          if not os.path.exists(outfile_dir):
              parser_in.error(f'{option_string} {outfile_dir} not exists')
          if not os.path.isfile(outfile_dir):
              parser_in.error(f'{option_string} {outfile_dir} is not a file')
          if not self._is_python_file(outfile_dir):
              parser_in.error(f'{option_string} {outfile_dir} is not a valid python file')
          setattr(namespace, self.dest, outfile_dir)
  class ModelFileAction(argparse.Action):
      """Action for the model option: an existing file of a recognised framework."""

      def __call__(self, parser_in, namespace, values, option_string=None):
          """
          Validate the model path and store its absolute form.

          The path must exist, must be a regular file, and must not be classified
          as ``FrameworkType.UNKNOWN`` by ``get_framework_type``.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              namespace (Namespace): Namespace object to hold arguments.
              values (str): Model path given on the command line.
              option_string (str): Optional string for specific argument name. Default: None.
          """
          ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
          model_path = FileDirAction.check_path(parser_in, values, option_string)

          # Each failed predicate reports its own problem, in this order.
          for predicate, problem in ((os.path.exists, 'not exists'),
                                     (os.path.isfile, 'is not a file')):
              if not predicate(model_path):
                  parser_in.error(f'{option_string} {model_path} {problem}')

          if get_framework_type(model_path) == FrameworkType.UNKNOWN.value:
              parser_in.error(f'{option_string} {model_path} should be '
                              f'a valid TensorFlow(.pb) or an ONNX(.onnx) model file.')
          setattr(namespace, self.dest, model_path)
  class LogFileAction(argparse.Action):
      """Action for the report option: a path that, when it exists, must be a directory."""

      def __call__(self, parser_in, namespace, values, option_string=None):
          """
          Validate the report directory and store its absolute path.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              namespace (Namespace): Namespace object to hold arguments.
              values (str): Report directory given on the command line.
              option_string (str): Optional string for specific argument name. Default: None.
          """
          ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
          report_dir = FileDirAction.check_path(parser_in, values, option_string)
          # A missing path is fine here; it is only rejected if it exists as a non-directory.
          names_non_directory = os.path.exists(report_dir) and not os.path.isdir(report_dir)
          if names_non_directory:
              parser_in.error(f'{option_string} {report_dir} is not a directory')
          setattr(namespace, self.dest, report_dir)
  class ShapeAction(argparse.Action):
      """Action for the shape option: one comma-separated integer list per model input."""

      def __call__(self, parser_in, namespace, values, option_string=None):
          """
          Parse every shape string into a list of ints and store the list of shapes.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              namespace (Namespace): Namespace object to hold arguments.
              values (list): Raw shape strings such as ``'1,512'``.
              option_string (str): Optional string for specific argument name. Default: None.
          """
          ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
          if len(values) > ARGUMENT_NUM_LIMIT:
              parser_in.error(f"The length of {option_string} {values} should be no more than {ARGUMENT_NUM_LIMIT}.")

          shapes = []
          for raw_shape in values:
              try:
                  dims = [int(dim) for dim in raw_shape.split(',')]
              except ValueError:
                  parser_in.error(
                      f"{option_string} {values} should be list of integers split by ',', check it please.")
                  # Never store a partial result after a parse failure.
                  return
              if len(dims) > ARGUMENT_LEN_LIMIT:
                  parser_in.error(
                      f"The length of {option_string} {dims} should be no more than {ARGUMENT_LEN_LIMIT}.")
              shapes.append(dims)
          setattr(namespace, self.dest, shapes)
  class NodeAction(argparse.Action):
      """Action for node-name options (used for both input and output node lists)."""

      @staticmethod
      def _find_duplicates(nodes):
          """
          Collect node names that occur more than once.

          Args:
              nodes (list): Node names in command-line order.

          Returns:
              list, each repeated name exactly once, ordered by where it first repeats.
          """
          seen = set()
          duplicated = {}  # dict keys act as an insertion-ordered set
          for node in nodes:
              if node in seen:
                  duplicated[node] = None
              else:
                  seen.add(node)
          return list(duplicated)

      def __call__(self, parser_in, namespace, values, option_string=None):
          """
          Validate node names (count, per-name length, uniqueness) and store them.

          Args:
              parser_in (ArgumentParser): Passed-in argument parser.
              namespace (Namespace): Namespace object to hold arguments.
              values (list): Node names given on the command line.
              option_string (str): Optional string for specific argument name. Default: None.
          """
          ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
          if len(values) > ARGUMENT_NUM_LIMIT:
              parser_in.error(f"The length of {option_string} {values} should be no more than {ARGUMENT_NUM_LIMIT}.")
          for node in values:
              if len(node) > ARGUMENT_LENGTH_LIMIT:
                  parser_in.error(
                      f"The length of {option_string} {node} should be no more than {ARGUMENT_LENGTH_LIMIT}."
                  )
          # Report each duplicated name once, even if it is repeated three or more times.
          duplicated = self._find_duplicates(values)
          if duplicated:
              verb = 'is' if len(duplicated) == 1 else 'are'
              parser_in.error(f"{', '.join(duplicated)} {verb} duplicated.")
          setattr(namespace, self.dest, values)
  # Command-line parser for the `mindconverter` entry point. Each path/list option is
  # validated by a custom argparse action defined in this module.
  parser = argparse.ArgumentParser(
      prog='mindconverter',
      description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
      allow_abbrev=False)

  parser.add_argument(
      '--version',
      action='version',
      version='%(prog)s ({})'.format(mindinsight.__version__))

  parser.add_argument(
      '--in_file',
      type=str,
      action=InFileAction,
      required=False,
      default=None,
      help="""
          Specify path for script file to use AST schema to
          do script conversion.
      """)

  parser.add_argument(
      '--model_file',
      type=str,
      action=ModelFileAction,
      required=False,
      default=None,
      help="""
          TensorFlow(.pb) or ONNX(.onnx) model file path used to
          generate scripts with the graph based schema. When both
          `--in_file` and `--model_file` are provided, the AST
          schema is used.
      """)

  parser.add_argument(
      '--shape',
      type=str,
      action=ShapeAction,
      default=None,
      required=False,
      nargs="+",
      help="""
          Optional, expected input tensor shape of
          `--model_file`. It is required when using the graph based
          schema. Both order and number should be consistent with `--input_nodes`.
          Usage: --shape 1,512 1,512
      """)

  parser.add_argument(
      '--input_nodes',
      type=str,
      action=NodeAction,
      default=None,
      required=False,
      nargs="+",
      help="""
          Optional, input node(s) name of `--model_file`. It is required when using the graph based schema.
          Both order and number should be consistent with `--shape`. Usage: --input_nodes input_1:0 input_2:0
      """)

  parser.add_argument(
      '--output_nodes',
      type=str,
      action=NodeAction,
      default=None,
      required=False,
      nargs="+",
      help="""
          Optional, output node(s) name of `--model_file`. It is required when using the graph based schema.
          Usage: --output_nodes output_1:0 output_2:0
      """)

  parser.add_argument(
      '--output',
      type=str,
      action=OutputDirAction,
      # Evaluated once, when this module is imported.
      default=os.path.join(os.getcwd(), 'output'),
      help="""
          Optional, specify path for converted script file
          directory. Default output directory is `output` folder
          in the current working directory.
      """)

  parser.add_argument(
      '--report',
      type=str,
      action=LogFileAction,
      default=None,
      help="""
          Optional, specify report directory. Default is
          converted script directory.
      """)
  def cli_entry():
      """
      Entry point for the mindconverter CLI.

      Restricts the process umask so new files get no group/other permissions.
      Parses the command line, showing help when no arguments are given.
      Creates the output and report directories with owner-only permissions,
      then runs the conversion.
      """
      rwx = os.R_OK | os.W_OK | os.X_OK
      # Mask out the rwx bits for both group (<< 3) and others.
      os.umask((rwx << 3) | rwx)

      # An empty command line is treated as `-h`, which prints help and exits.
      args = parser.parse_args(sys.argv[1:] or ['-h'])

      owner_only = rwx << 6
      os.makedirs(args.output, mode=owner_only, exist_ok=True)
      if args.report is None:
          # Reports go next to the converted scripts unless a directory was given.
          args.report = args.output
      os.makedirs(args.report, mode=owner_only, exist_ok=True)

      _run(args.in_file, args.model_file, args.shape, args.input_nodes,
           args.output_nodes, args.output, args.report)
  def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report):
      """
      Dispatch to the AST-based or the graph-based converter.

      When ``in_files`` is given, the AST converter runs even if ``model_file``
      is also set. When neither is given, an error is logged and the process
      exits with status -1.

      Args:
          in_files (str): Script file to convert, or a directory whose files are all collected.
          model_file (str): The model to convert on graph based schema.
          shape (list): Input tensor shapes of ``model_file``, one list of ints per input.
          input_nodes (list): The input node name(s) of the model.
          output_nodes (list): The output node name(s) of the model.
          out_dir (str): The output directory to save converted files.
          report (str): Report directory; ``out_dir`` is used when it is empty.
      """
      report_dir = report or out_dir

      if in_files:
          if os.path.isfile(in_files):
              root_path, script_files = os.path.dirname(in_files), [in_files]
          else:
              root_path = in_files
              script_files = [os.path.join(root_dir, name)
                              for root_dir, _, names in os.walk(in_files)
                              for name in names]
          main({
              'root_path': root_path,
              'in_files': script_files,
              'outfile_dir': out_dir,
              'report_dir': report_dir,
          })
          log_console.info("MindConverter: conversion is completed.")
      elif model_file:
          main_graph_base_converter({
              'model_file': model_file,
              'shape': shape or [],
              'input_nodes': input_nodes,
              'output_nodes': output_nodes,
              'outfile_dir': out_dir,
              'report_dir': report_dir,
          })
          log_console.info("MindConverter: conversion is completed.")
      else:
          error_msg = "`--in_file` and `--model_file` should be set at least one."
          log.error(error_msg)
          log_console.error(f"mindconverter: error: {error_msg}")
          sys.exit(-1)