You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

framework.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Graph based scripts converter workflow."""
  16. import os
  17. import argparse
  18. from importlib.util import find_spec
  19. import mindinsight
  20. from mindinsight.mindconverter.common.log import logger as log
  21. from .mapper import ONNXToMindSporeMapper
  22. from ..common.exceptions import NodeTypeNotSupport
  23. permissions = os.R_OK | os.W_OK | os.X_OK
  24. os.umask(permissions << 3 | permissions)
  25. parser = argparse.ArgumentParser(
  26. prog="MindConverter",
  27. description="Graph based MindConverter CLI entry point (version: {})".format(
  28. mindinsight.__version__)
  29. )
  30. parser.add_argument("--graph", type=str, required=True,
  31. help="Third party framework's graph path.")
  32. parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
  33. help="Input shape of the model.")
  34. parser.add_argument("--ckpt", type=str, required=False,
  35. help="Third party framework's checkpoint path.")
  36. parser.add_argument("--output", type=str, required=True,
  37. help="Generated scripts output folder path.")
  38. parser.add_argument("--report", type=str, required=False,
  39. help="Generated reports output folder path.")
  40. def torch_installation_validation(func):
  41. """
  42. Validate args of func.
  43. Args:
  44. func (type): Function.
  45. Returns:
  46. type, inner function.
  47. """
  48. def _f(graph_path: str, sample_shape: tuple,
  49. output_folder: str, report_folder: str = None,
  50. checkpoint_path: str = None):
  51. # Check whether pytorch is installed.
  52. if not find_spec("torch"):
  53. error = ModuleNotFoundError("PyTorch is required when using graph based "
  54. "scripts converter, and PyTorch vision must "
  55. "be consisted with model generation runtime.")
  56. log.error(str(error))
  57. log.exception(error)
  58. raise error
  59. func(graph_path=graph_path, sample_shape=sample_shape,
  60. output_folder=output_folder, report_folder=report_folder,
  61. checkpoint_path=checkpoint_path)
  62. return _f
  63. @torch_installation_validation
  64. def graph_based_converter(graph_path: str, sample_shape: tuple,
  65. output_folder: str, report_folder: str = None,
  66. checkpoint_path: str = None):
  67. """
  68. Graph based scripts converter.
  69. Args:
  70. graph_path (str): Graph file path.
  71. sample_shape (tuple): Input shape of the model.
  72. output_folder (str): Output folder.
  73. report_folder (str): Report output folder path.
  74. checkpoint_path (str): Checkpoint file path.
  75. """
  76. from .third_party_graph import GraphFactory
  77. from .hierarchical_tree import HierarchicalTreeFactory
  78. graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
  79. checkpoint=checkpoint_path)
  80. try:
  81. hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
  82. except Exception as e:
  83. log.exception(e)
  84. log.error("Error occur when create hierarchical tree.")
  85. raise NodeTypeNotSupport("This model is not supported now.")
  86. hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
  87. report_folder=report_folder)
  88. def main_graph_base_converter(file_config):
  89. """
  90. The entrance for converter, script files will be converted.
  91. Args:
  92. file_config (dict): The config of file which to convert.
  93. """
  94. graph_based_converter(graph_path=file_config['model_file'],
  95. sample_shape=file_config['shape'],
  96. output_folder=file_config['outfile_dir'],
  97. report_folder=file_config['report_dir'])