|
- # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Graph based scripts converter workflow."""
- import os
- import argparse
- from importlib.util import find_spec
-
- import mindinsight
- from mindinsight.mindconverter.common.log import logger as log
- from .mapper import ONNXToMindSporeMapper
- from ..common.exceptions import NodeTypeNotSupport
-
# Combine read, write and execute into a single permission triplet (0o7).
permissions = os.X_OK | os.W_OK | os.R_OK
# Place the triplet in both the "group" and the "other" position of the mask
# (0o077), so those permission bits are cleared on every file and folder
# created by this process from here on.
os.umask(permissions | (permissions << 3))
-
# Command line interface of the graph based converter. The arguments are
# registered in the order in which they are listed in the usage/help text.
parser = argparse.ArgumentParser(
    prog="MindConverter",
    description="Graph based MindConverter CLI entry point (version: {})".format(
        mindinsight.__version__
    ),
)

# Mandatory: path of the graph file produced by the third party framework.
parser.add_argument(
    "--graph",
    required=True,
    type=str,
    help="Third party framework's graph path.",
)
# Mandatory: one or more integers, collected into a list.
parser.add_argument(
    "--sample_shape",
    required=True,
    type=int,
    nargs='+',
    help="Input shape of the model.",
)
# Optional: checkpoint file of the third party framework.
parser.add_argument(
    "--ckpt",
    required=False,
    type=str,
    help="Third party framework's checkpoint path.",
)
# Mandatory: folder which receives the generated scripts.
parser.add_argument(
    "--output",
    required=True,
    type=str,
    help="Generated scripts output folder path.",
)
# Optional: folder which receives the generated reports.
parser.add_argument(
    "--report",
    required=False,
    type=str,
    help="Generated reports output folder path.",
)
-
-
def torch_installation_validation(func):
    """
    Check that PyTorch can be found before running the decorated converter.

    Args:
        func (Callable): Converter entry point. It is invoked with the keyword
            arguments `graph_path`, `sample_shape`, `output_folder`,
            `report_folder` and `checkpoint_path`.

    Returns:
        Callable, wrapper which performs the check, then delegates to `func`
        and returns its result.

    Raises:
        ModuleNotFoundError: Raised by the wrapper when no `torch` module
            can be located.
    """
    # Function-scope import keeps this block independent of the module header.
    from functools import wraps

    @wraps(func)  # Keep the name and docstring of the decorated function.
    def _f(graph_path: str, sample_shape: tuple,
           output_folder: str, report_folder: str = None,
           checkpoint_path: str = None):
        # Check whether pytorch is installed.
        if not find_spec("torch"):
            error = ModuleNotFoundError("PyTorch is required when using graph based "
                                        "scripts converter, and PyTorch version must "
                                        "be consistent with model generation runtime.")
            # No exception is being handled at this point, so only a plain
            # error record is written (`exception`-style logging is meant for
            # use inside an exception handler).
            log.error(str(error))
            raise error

        # Hand the result of the wrapped function back to the caller.
        return func(graph_path=graph_path, sample_shape=sample_shape,
                    output_folder=output_folder, report_folder=report_folder,
                    checkpoint_path=checkpoint_path)

    return _f
-
-
@torch_installation_validation
def graph_based_converter(graph_path: str, sample_shape: tuple,
                          output_folder: str, report_folder: str = None,
                          checkpoint_path: str = None):
    """
    Graph based scripts converter.

    Loads the graph, builds a hierarchical tree from it and saves the
    generated source files (and reports) through that tree.

    Args:
        graph_path (str): Graph file path.
        sample_shape (tuple): Input shape of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
        checkpoint_path (str): Checkpoint file path.

    Raises:
        NodeTypeNotSupport: If the hierarchical tree can not be created from
            the loaded graph. The original error is attached as its cause.
    """
    # Deferred imports: they only run when the converter is actually called,
    # i.e. after the decorator's PyTorch check has passed.
    from .third_party_graph import GraphFactory
    from .hierarchical_tree import HierarchicalTreeFactory

    graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
                                  checkpoint=checkpoint_path)
    try:
        hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
    except Exception as e:
        log.exception(e)
        log.error("Error occur when create hierarchical tree.")
        # Chain explicitly so the root cause is reported as the direct cause
        # of the user-facing error.
        raise NodeTypeNotSupport("This model is not supported now.") from e

    hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                                        report_folder=report_folder)
-
-
def main_graph_base_converter(file_config):
    """
    Entry point of the converter: run the graph based conversion.

    Args:
        file_config (dict): Conversion settings. The keys `model_file`,
            `shape`, `outfile_dir` and `report_dir` are read; a missing key
            raises `KeyError`.
    """
    # NOTE(review): no checkpoint path is forwarded here, so the converter
    # runs with its default `checkpoint_path`.
    converter_kwargs = {
        "graph_path": file_config['model_file'],
        "sample_shape": file_config['shape'],
        "output_folder": file_config['outfile_dir'],
        "report_folder": file_config['report_dir'],
    }
    graph_based_converter(**converter_kwargs)
|