You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cli.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Command module."""
  16. import os
  17. import sys
  18. import argparse
  19. import mindinsight
  20. from mindinsight.mindconverter.converter import main
  21. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  22. from mindinsight.mindconverter.common.log import logger as log
  23. class FileDirAction(argparse.Action):
  24. """File directory action class definition."""
  25. @staticmethod
  26. def check_path(parser_in, values, option_string=None):
  27. """
  28. Check argument for file path.
  29. Args:
  30. parser_in (ArgumentParser): Passed-in argument parser.
  31. values (object): Argument values with type depending on argument definition.
  32. option_string (str): Optional string for specific argument name. Default: None.
  33. """
  34. outfile = values
  35. if outfile.startswith('~'):
  36. outfile = os.path.realpath(os.path.expanduser(outfile))
  37. if not outfile.startswith('/'):
  38. outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
  39. if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
  40. parser_in.error(f'{option_string} {outfile} not accessible')
  41. return outfile
  42. def __call__(self, parser_in, namespace, values, option_string=None):
  43. """
  44. Inherited __call__ method from argparse.Action.
  45. Args:
  46. parser_in (ArgumentParser): Passed-in argument parser.
  47. namespace (Namespace): Namespace object to hold arguments.
  48. values (object): Argument values with type depending on argument definition.
  49. option_string (str): Optional string for specific argument name. Default: None.
  50. """
  51. outfile_dir = self.check_path(parser_in, values, option_string)
  52. if os.path.isfile(outfile_dir):
  53. parser_in.error(f'{option_string} {outfile_dir} is a file')
  54. setattr(namespace, self.dest, outfile_dir)
  55. class OutputDirAction(argparse.Action):
  56. """File directory action class definition."""
  57. def __call__(self, parser_in, namespace, values, option_string=None):
  58. """
  59. Inherited __call__ method from argparse.Action.
  60. Args:
  61. parser_in (ArgumentParser): Passed-in argument parser.
  62. namespace (Namespace): Namespace object to hold arguments.
  63. values (object): Argument values with type depending on argument definition.
  64. option_string (str): Optional string for specific argument name. Default: None.
  65. """
  66. output = values
  67. if output.startswith('~'):
  68. output = os.path.realpath(os.path.expanduser(output))
  69. if not output.startswith('/'):
  70. output = os.path.realpath(os.path.join(os.getcwd(), output))
  71. if os.path.exists(output):
  72. if not os.access(output, os.R_OK):
  73. parser_in.error(f'{option_string} {output} not accessible')
  74. if os.path.isfile(output):
  75. parser_in.error(f'{option_string} {output} is a file')
  76. setattr(namespace, self.dest, output)
  77. class ProjectPathAction(argparse.Action):
  78. """Project directory action class definition."""
  79. def __call__(self, parser_in, namespace, values, option_string=None):
  80. """
  81. Inherited __call__ method from argparse.Action.
  82. Args:
  83. parser_in (ArgumentParser): Passed-in argument parser.
  84. namespace (Namespace): Namespace object to hold arguments.
  85. values (object): Argument values with type depending on argument definition.
  86. option_string (str): Optional string for specific argument name. Default: None.
  87. """
  88. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  89. if not os.path.exists(outfile_dir):
  90. parser_in.error(f'{option_string} {outfile_dir} not exists')
  91. if not os.path.isdir(outfile_dir):
  92. parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')
  93. setattr(namespace, self.dest, outfile_dir)
  94. class InFileAction(argparse.Action):
  95. """Input File action class definition."""
  96. def __call__(self, parser_in, namespace, values, option_string=None):
  97. """
  98. Inherited __call__ method from argparse.Action.
  99. Args:
  100. parser_in (ArgumentParser): Passed-in argument parser.
  101. namespace (Namespace): Namespace object to hold arguments.
  102. values (object): Argument values with type depending on argument definition.
  103. option_string (str): Optional string for specific argument name. Default: None.
  104. """
  105. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  106. if not os.path.exists(outfile_dir):
  107. parser_in.error(f'{option_string} {outfile_dir} not exists')
  108. if not os.path.isfile(outfile_dir):
  109. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  110. setattr(namespace, self.dest, outfile_dir)
  111. class ModelFileAction(argparse.Action):
  112. """Model File action class definition."""
  113. def __call__(self, parser_in, namespace, values, option_string=None):
  114. """
  115. Inherited __call__ method from argparse.Action.
  116. Args:
  117. parser_in (ArgumentParser): Passed-in argument parser.
  118. namespace (Namespace): Namespace object to hold arguments.
  119. values (object): Argument values with type depending on argument definition.
  120. option_string (str): Optional string for specific argument name. Default: None.
  121. """
  122. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  123. if not os.path.exists(outfile_dir):
  124. parser_in.error(f'{option_string} {outfile_dir} not exists')
  125. if not os.path.isfile(outfile_dir):
  126. parser_in.error(f'{option_string} {outfile_dir} is not a file')
  127. if not outfile_dir.endswith('.pth'):
  128. parser_in.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.")
  129. setattr(namespace, self.dest, outfile_dir)
  130. class LogFileAction(argparse.Action):
  131. """Log file action class definition."""
  132. def __call__(self, parser_in, namespace, values, option_string=None):
  133. """
  134. Inherited __call__ method from FileDirAction.
  135. Args:
  136. parser_in (ArgumentParser): Passed-in argument parser.
  137. namespace (Namespace): Namespace object to hold arguments.
  138. values (object): Argument values with type depending on argument definition.
  139. option_string (str): Optional string for specific argument name. Default: None.
  140. """
  141. outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
  142. if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
  143. parser_in.error(f'{option_string} {outfile_dir} is not a directory')
  144. setattr(namespace, self.dest, outfile_dir)
  145. class ShapeAction(argparse.Action):
  146. """Shape action class definition."""
  147. def __call__(self, parser_in, namespace, values, option_string=None):
  148. """
  149. Inherited __call__ method from FileDirAction.
  150. Args:
  151. parser_in (ArgumentParser): Passed-in argument parser.
  152. namespace (Namespace): Namespace object to hold arguments.
  153. values (object): Argument values with type depending on argument definition.
  154. option_string (str): Optional string for specific argument name. Default: None.
  155. """
  156. in_shape = None
  157. shape_str = values
  158. try:
  159. in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
  160. except ValueError:
  161. parser_in.error(
  162. f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
  163. setattr(namespace, self.dest, in_shape)
  164. parser = argparse.ArgumentParser(
  165. prog='mindconverter',
  166. description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
  167. parser.add_argument(
  168. '--version',
  169. action='version',
  170. version='%(prog)s ({})'.format(mindinsight.__version__))
  171. parser.add_argument(
  172. '--in_file',
  173. type=str,
  174. action=InFileAction,
  175. required=False,
  176. default=None,
  177. help="""
  178. Specify path for script file to use AST schema to
  179. do script conversation.
  180. """)
  181. parser.add_argument(
  182. '--model_file',
  183. type=str,
  184. action=ModelFileAction,
  185. required=False,
  186. help="""
  187. PyTorch .pth model file path to use graph
  188. based schema to do script generation. When
  189. `--in_file` and `--model_file` are both provided,
  190. use AST schema as default.
  191. """)
  192. parser.add_argument(
  193. '--shape',
  194. type=str,
  195. action=ShapeAction,
  196. default=None,
  197. required=False,
  198. help="""
  199. Optional, expected input tensor shape of
  200. `--model_file`. It's required when use graph based
  201. schema.
  202. Usage: --shape 3,244,244
  203. """)
  204. parser.add_argument(
  205. '--output',
  206. type=str,
  207. action=OutputDirAction,
  208. default=os.path.join(os.getcwd(), 'output'),
  209. help="""
  210. Optional, specify path for converted script file
  211. directory. Default output directory is `output` folder
  212. in the current working directory.
  213. """)
  214. parser.add_argument(
  215. '--report',
  216. type=str,
  217. action=LogFileAction,
  218. default=None,
  219. help="""
  220. Optional, specify report directory. Default is
  221. converted script directory.
  222. """)
  223. parser.add_argument(
  224. '--project_path',
  225. type=str,
  226. action=ProjectPathAction,
  227. required=False,
  228. default=None,
  229. help="""
  230. Optional, PyTorch scripts project path. If PyTorch
  231. project is not in PYTHONPATH, please assign
  232. `--project_path` when use graph based schema.
  233. Usage: --project_path ~/script_file/
  234. """)
  235. def cli_entry():
  236. """Entry point for mindconverter CLI."""
  237. permissions = os.R_OK | os.W_OK | os.X_OK
  238. os.umask(permissions << 3 | permissions)
  239. argv = sys.argv[1:]
  240. if not argv:
  241. argv = ['-h']
  242. args = parser.parse_args(argv)
  243. else:
  244. args = parser.parse_args()
  245. mode = permissions << 6
  246. os.makedirs(args.output, mode=mode, exist_ok=True)
  247. if args.report is None:
  248. args.report = args.output
  249. os.makedirs(args.report, mode=mode, exist_ok=True)
  250. _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path)
  251. def _run(in_files, model_file, shape, out_dir, report, project_path):
  252. """
  253. Run converter command.
  254. Args:
  255. in_files (str): The file path or directory to convert.
  256. model_file(str): The pytorch .pth to convert on graph based schema.
  257. shape(list): The input tensor shape of module_file.
  258. out_dir (str): The output directory to save converted file.
  259. report (str): The report file path.
  260. project_path(str): Pytorch scripts project path.
  261. """
  262. if in_files:
  263. files_config = {
  264. 'root_path': in_files,
  265. 'in_files': [],
  266. 'outfile_dir': out_dir,
  267. 'report_dir': report if report else out_dir
  268. }
  269. if os.path.isfile(in_files):
  270. files_config['root_path'] = os.path.dirname(in_files)
  271. files_config['in_files'] = [in_files]
  272. else:
  273. for root_dir, _, files in os.walk(in_files):
  274. for file in files:
  275. files_config['in_files'].append(os.path.join(root_dir, file))
  276. main(files_config)
  277. elif model_file:
  278. file_config = {
  279. 'model_file': model_file,
  280. 'shape': shape if shape else [],
  281. 'outfile_dir': out_dir,
  282. 'report_dir': report if report else out_dir
  283. }
  284. if project_path:
  285. paths = sys.path
  286. if project_path not in paths:
  287. sys.path.append(project_path)
  288. main_graph_base_converter(file_config)
  289. else:
  290. error_msg = "`--in_file` and `--model_file` should be set at least one."
  291. error = FileNotFoundError(error_msg)
  292. log.error(str(error))
  293. log.exception(error)
  294. raise error