You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

cli.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  22. from mindinsight.mindconverter.common.log import logger as log
  23. class FileDirAction(argparse.Action):
  24. """File directory action class definition."""
  25. @staticmethod
  26. def check_path(parser, values, option_string=None):
  27. """
  28. Check argument for file path.
  29. Args:
  30. parser (ArgumentParser): Passed-in argument parser.
  31. values (object): Argument values with type depending on argument definition.
  32. option_string (str): Optional string for specific argument name. Default: None.
  33. """
  34. outfile = values
  35. if outfile.startswith('~'):
  36. outfile = os.path.realpath(os.path.expanduser(outfile))
  37. if not outfile.startswith('/'):
  38. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  39. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  40. parser.error(f'{option_string} {outfile} not accessible')
  41. return outfile
  42. def __call__(self, parser, namespace, values, option_string=None):
  43. """
  44. Inherited __call__ method from argparse.Action.
  45. Args:
  46. parser (ArgumentParser): Passed-in argument parser.
  47. namespace (Namespace): Namespace object to hold arguments.
  48. values (object): Argument values with type depending on argument definition.
  49. option_string (str): Optional string for specific argument name. Default: None.
  50. """
  51. outfile_dir = self.check_path(parser, values, option_string)
  52. if os.path.isfile(outfile_dir):
  53. parser.error(f'{option_string} {outfile_dir} is a file')
  54. setattr(namespace, self.dest, outfile_dir)
  55. class OutputDirAction(argparse.Action):
  56. """File directory action class definition."""
  57. def __call__(self, parser, namespace, values, option_string=None):
  58. """
  59. Inherited __call__ method from argparse.Action.
  60. Args:
  61. parser (ArgumentParser): Passed-in argument parser.
  62. namespace (Namespace): Namespace object to hold arguments.
  63. values (object): Argument values with type depending on argument definition.
  64. option_string (str): Optional string for specific argument name. Default: None.
  65. """
  66. output = values
  67. if output.startswith('~'):
  68. output = os.path.realpath(os.path.expanduser(output))
  69. if not output.startswith('/'):
  70. output = os.path.realpath(os.path.join(os.getcwd(), output))
  71. if os.path.exists(output):
  72. if not os.access(output, os.R_OK):
  73. parser.error(f'{option_string} {output} not accessible')
  74. if os.path.isfile(output):
  75. parser.error(f'{option_string} {output} is a file')
  76. setattr(namespace, self.dest, output)
  77. class ProjectPathAction(argparse.Action):
  78. """Project directory action class definition."""
  79. def __call__(self, parser, namespace, values, option_string=None):
  80. """
  81. Inherited __call__ method from argparse.Action.
  82. Args:
  83. parser (ArgumentParser): Passed-in argument parser.
  84. namespace (Namespace): Namespace object to hold arguments.
  85. values (object): Argument values with type depending on argument definition.
  86. option_string (str): Optional string for specific argument name. Default: None.
  87. """
  88. outfile_dir = FileDirAction.check_path(parser, values, option_string)
  89. if not os.path.isdir(outfile_dir):
  90. parser.error(f'{option_string} [{outfile_dir}] should be a directory.')
  91. setattr(namespace, self.dest, outfile_dir)
  92. class InFileAction(argparse.Action):
  93. """Input File action class definition."""
  94. def __call__(self, parser, namespace, values, option_string=None):
  95. """
  96. Inherited __call__ method from argparse.Action.
  97. Args:
  98. parser (ArgumentParser): Passed-in argument parser.
  99. namespace (Namespace): Namespace object to hold arguments.
  100. values (object): Argument values with type depending on argument definition.
  101. option_string (str): Optional string for specific argument name. Default: None.
  102. """
  103. outfile_dir = FileDirAction.check_path(parser, values, option_string)
  104. if not os.path.exists(outfile_dir):
  105. parser.error(f'{option_string} {outfile_dir} not exists')
  106. if not os.path.isfile(outfile_dir):
  107. parser.error(f'{option_string} {outfile_dir} is not a file')
  108. setattr(namespace, self.dest, outfile_dir)
  109. class LogFileAction(argparse.Action):
  110. """Log file action class definition."""
  111. def __call__(self, parser, namespace, values, option_string=None):
  112. """
  113. Inherited __call__ method from FileDirAction.
  114. Args:
  115. parser (ArgumentParser): Passed-in argument parser.
  116. namespace (Namespace): Namespace object to hold arguments.
  117. values (object): Argument values with type depending on argument definition.
  118. option_string (str): Optional string for specific argument name. Default: None.
  119. """
  120. outfile_dir = FileDirAction.check_path(parser, values, option_string)
  121. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  122. parser.error(f'{option_string} {outfile_dir} is not a directory')
  123. setattr(namespace, self.dest, outfile_dir)
  124. class ShapeAction(argparse.Action):
  125. """Shape action class definition."""
  126. def __call__(self, parser, namespace, values, option_string=None):
  127. """
  128. Inherited __call__ method from FileDirAction.
  129. Args:
  130. parser (ArgumentParser): Passed-in argument parser.
  131. namespace (Namespace): Namespace object to hold arguments.
  132. values (object): Argument values with type depending on argument definition.
  133. option_string (str): Optional string for specific argument name. Default: None.
  134. """
  135. in_shape = None
  136. shape_str = values
  137. try:
  138. in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
  139. except ValueError:
  140. parser.error(
  141. f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
  142. setattr(namespace, self.dest, in_shape)
  143. def cli_entry():
  144. """Entry point for mindconverter CLI."""
  145. permissions = os.R_OK | os.W_OK | os.X_OK
  146. os.umask(permissions << 3 | permissions)
  147. parser = argparse.ArgumentParser(
  148. prog='mindconverter',
  149. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
  150. parser.add_argument(
  151. '--version',
  152. action='version',
  153. version='%(prog)s ({})'.format(mindinsight.__version__))
  154. parser.add_argument(
  155. '--in_file',
  156. type=str,
  157. action=InFileAction,
  158. required=False,
  159. default=None,
  160. help="""
  161. Specify path for script file.
  162. """)
  163. parser.add_argument(
  164. '--model_file',
  165. type=str,
  166. action=InFileAction,
  167. required=False,
  168. help="""
  169. Pytorch .pth model file path ot use graph
  170. based schema to do script generation. When
  171. `--in_file` and `--model_path` are both provided,
  172. use AST schema as default.
  173. Usage: --model_file ~/pytorch_file/net.pth.
  174. """)
  175. parser.add_argument(
  176. '--shape',
  177. type=str,
  178. action=ShapeAction,
  179. default=None,
  180. required=False,
  181. help="""
  182. Optional, excepted input tensor shape of
  183. `--model_file`. It's required when use graph based
  184. schema.
  185. Usage: --shape 3,244,244
  186. """)
  187. parser.add_argument(
  188. '--output',
  189. type=str,
  190. action=OutputDirAction,
  191. default=os.path.join(os.getcwd(), 'output'),
  192. help="""
  193. Specify path for converted script file directory.
  194. Default is output directory in the current working directory.
  195. """)
  196. parser.add_argument(
  197. '--report',
  198. type=str,
  199. action=LogFileAction,
  200. default=None,
  201. help="""
  202. Specify report directory. Default is the current working directory.
  203. """)
  204. parser.add_argument(
  205. '--project_path',
  206. type=str,
  207. action=ProjectPathAction,
  208. required=False,
  209. default=None,
  210. help="""
  211. Optional, pytorch scripts project path. If pytorch
  212. project is not in PYTHONPATH, please assign
  213. `--project_path' when use graph based schema.
  214. Usage: --project_path ~/script_file/
  215. """)
  216. argv = sys.argv[1:]
  217. if not argv:
  218. argv = ['-h']
  219. args = parser.parse_args(argv)
  220. else:
  221. args = parser.parse_args()
  222. mode = permissions << 6
  223. os.makedirs(args.output, mode=mode, exist_ok=True)
  224. if args.report is None:
  225. args.report = args.output
  226. os.makedirs(args.report, mode=mode, exist_ok=True)
  227. _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path)
  228. def _run(in_files, model_file, shape, out_dir, report, project_path):
  229. """
  230. Run converter command.
  231. Args:
  232. in_files (str): The file path or directory to convert.
  233. model_file(str): The pytorch .pth to convert on graph based schema.
  234. shape(list): The input tensor shape of module_file.
  235. out_dir (str): The output directory to save converted file.
  236. report (str): The report file path.
  237. project_path(str): Pytorch scripts project path.
  238. """
  239. if in_files:
  240. files_config = {
  241. 'root_path': in_files,
  242. 'in_files': [],
  243. 'outfile_dir': out_dir,
  244. 'report_dir': report if report else out_dir
  245. }
  246. if os.path.isfile(in_files):
  247. files_config['root_path'] = os.path.dirname(in_files)
  248. files_config['in_files'] = [in_files]
  249. else:
  250. for root_dir, _, files in os.walk(in_files):
  251. for file in files:
  252. files_config['in_files'].append(os.path.join(root_dir, file))
  253. main(files_config)
  254. elif model_file:
  255. file_config = {
  256. 'model_file': model_file,
  257. 'shape': shape if shape else [],
  258. 'outfile_dir': out_dir,
  259. 'report_dir': report if report else out_dir
  260. }
  261. if project_path:
  262. paths = sys.path
  263. if project_path not in paths:
  264. sys.path.append(project_path)
  265. main_graph_base_converter(file_config)
  266. else:
  267. error_msg = "`--in_files` and `--model_file` should be set at least one."
  268. error = FileNotFoundError(error_msg)
  269. log.error(str(error))
  270. log.exception(error)
  271. raise error