Browse Source

Add cli & File Permission

tags/v1.0.0
moran 5 years ago
parent
commit
a8cbed83d0
7 changed files with 263 additions and 60 deletions
  1. +132
    -18
      mindinsight/mindconverter/cli.py
  2. +18
    -10
      mindinsight/mindconverter/graph_based_converter/framework.py
  3. +28
    -10
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  4. +6
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  5. +20
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
  6. +59
    -11
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  7. +0
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py

+ 132
- 18
mindinsight/mindconverter/cli.py View File

@@ -19,6 +19,9 @@ import argparse


import mindinsight import mindinsight
from mindinsight.mindconverter.converter import main from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter

from mindinsight.mindconverter.common.log import logger as log




class FileDirAction(argparse.Action): class FileDirAction(argparse.Action):
@@ -92,6 +95,26 @@ class OutputDirAction(argparse.Action):
setattr(namespace, self.dest, output) setattr(namespace, self.dest, output)




class ProjectPathAction(argparse.Action):
"""Project directory action class definition."""

def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.

Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if not os.path.isdir(outfile_dir):
parser.error(f'{option_string} [{outfile_dir}] should be a directory.')

setattr(namespace, self.dest, outfile_dir)


class InFileAction(argparse.Action): class InFileAction(argparse.Action):
"""Input File action class definition.""" """Input File action class definition."""


@@ -134,6 +157,29 @@ class LogFileAction(argparse.Action):
setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)




class ShapeAction(argparse.Action):
"""Shape action class definition."""

def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from FileDirAction.

Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
in_shape = None
shape_str = values
try:
in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
except ValueError:
parser.error(
f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
setattr(namespace, self.dest, in_shape)


def cli_entry(): def cli_entry():
"""Entry point for mindconverter CLI.""" """Entry point for mindconverter CLI."""


@@ -153,9 +199,36 @@ def cli_entry():
'--in_file', '--in_file',
type=str, type=str,
action=InFileAction, action=InFileAction,
required=True,
required=False,
default=None,
help="""
Specify path for script file.
""")

parser.add_argument(
'--model_file',
type=str,
action=InFileAction,
required=False,
help=""" help="""
Specify path for script file.
Pytorch .pth model file path ot use graph
based schema to do script generation. When
`--in_file` and `--model_path` are both provided,
use AST schema as default.
Usage: --model_file ~/pytorch_file/net.pth.
""")

parser.add_argument(
'--shape',
type=str,
action=ShapeAction,
default=None,
required=False,
help="""
Optional, excepted input tensor shape of
`--model_file`. It's required when use graph based
schema.
Usage: --shape 3,244,244
""") """)


parser.add_argument( parser.add_argument(
@@ -172,11 +245,24 @@ def cli_entry():
'--report', '--report',
type=str, type=str,
action=LogFileAction, action=LogFileAction,
default=os.getcwd(),
default=None,
help=""" help="""
Specify report directory. Default is the current working directory. Specify report directory. Default is the current working directory.
""") """)


parser.add_argument(
'--project_path',
type=str,
action=ProjectPathAction,
required=False,
default=None,
help="""
Optional, pytorch scripts project path. If pytorch
project is not in PYTHONPATH, please assign
`--project_path' when use graph based schema.
Usage: --project_path ~/script_file/
""")

argv = sys.argv[1:] argv = sys.argv[1:]
if not argv: if not argv:
argv = ['-h'] argv = ['-h']
@@ -185,30 +271,58 @@ def cli_entry():
args = parser.parse_args() args = parser.parse_args()
mode = permissions << 6 mode = permissions << 6
os.makedirs(args.output, mode=mode, exist_ok=True) os.makedirs(args.output, mode=mode, exist_ok=True)
if args.report is None:
args.report = args.output
os.makedirs(args.report, mode=mode, exist_ok=True) os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.output, args.report)
_run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path)




def _run(in_files, out_dir, report):
def _run(in_files, model_file, shape, out_dir, report, project_path):
""" """
Run converter command. Run converter command.


Args: Args:
in_files (str): The file path or directory to convert. in_files (str): The file path or directory to convert.
model_file(str): The pytorch .pth to convert on graph based schema.
shape(list): The input tensor shape of module_file.
out_dir (str): The output directory to save converted file. out_dir (str): The output directory to save converted file.
report (str): The report file path. report (str): The report file path.
project_path(str): Pytorch scripts project path.
""" """
files_config = {
'root_path': in_files if in_files else '',
'in_files': [],
'outfile_dir': out_dir,
'report_dir': report
}
if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
files_config['in_files'] = [in_files]
if in_files:
files_config = {
'root_path': in_files,
'in_files': [],
'outfile_dir': out_dir,
'report_dir': report if report else out_dir
}

if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
files_config['in_files'] = [in_files]
else:
for root_dir, _, files in os.walk(in_files):
for file in files:
files_config['in_files'].append(os.path.join(root_dir, file))
main(files_config)

elif model_file:
file_config = {
'model_file': model_file,
'shape': shape if shape else [],
'outfile_dir': out_dir,
'report_dir': report if report else out_dir
}
if project_path:
paths = sys.path
if project_path not in paths:
sys.path.append(project_path)

main_graph_base_converter(file_config)

else: else:
for root_dir, _, files in os.walk(in_files):
for file in files:
files_config['in_files'].append(os.path.join(root_dir, file))
main(files_config)
error_msg = "`--in_files` and `--model_file` should be set at least one."
error = FileNotFoundError(error_msg)
log.error(str(error))
log.exception(error)
raise error

+ 18
- 10
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -18,6 +18,7 @@ import argparse
from importlib.util import find_spec from importlib.util import find_spec


import mindinsight import mindinsight
from mindinsight.mindconverter.common.log import logger as log
from .mapper import ONNXToMindSporeMapper from .mapper import ONNXToMindSporeMapper


permissions = os.R_OK | os.W_OK | os.X_OK permissions = os.R_OK | os.W_OK | os.X_OK
@@ -57,9 +58,12 @@ def torch_installation_validation(func):
checkpoint_path: str = None): checkpoint_path: str = None):
# Check whether pytorch is installed. # Check whether pytorch is installed.
if not find_spec("torch"): if not find_spec("torch"):
raise ModuleNotFoundError("PyTorch is required when using graph based "
"scripts converter, and PyTorch vision must "
"be consisted with model generation runtime.")
error = ModuleNotFoundError("PyTorch is required when using graph based "
"scripts converter, and PyTorch vision must "
"be consisted with model generation runtime.")
log.error(str(error))
log.exception(error)
raise error


func(graph_path=graph_path, sample_shape=sample_shape, func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder, output_folder=output_folder, report_folder=report_folder,
@@ -93,10 +97,14 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
report_folder=report_folder) report_folder=report_folder)




if __name__ == '__main__':
args, _ = parser.parse_known_args()
graph_based_converter(graph_path=args.graph,
sample_shape=args.sample_shape,
output_folder=args.output,
report_folder=args.report,
checkpoint_path=args.ckpt)
def main_graph_base_converter(file_config):
"""
The entrance for converter, script files will be converted.

Args:
file_config (dict): The config of file which to convert.
"""
graph_based_converter(graph_path=file_config['model_file'],
sample_shape=file_config['shape'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])

+ 28
- 10
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Define hierarchical tree.""" """Define hierarchical tree."""
import os import os
import stat
from copy import deepcopy from copy import deepcopy
from typing import NoReturn, Union from typing import NoReturn, Union
from queue import Queue from queue import Queue
@@ -21,6 +22,8 @@ from queue import Queue
from yapf.yapflib.yapf_api import FormatCode from yapf.yapflib.yapf_api import FormatCode
from treelib import Tree, Node from treelib import Tree, Node


from mindinsight.mindconverter.common.log import logger as log

from .name_mgr import ModuleNameMgr, GlobalVarNameMgr from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
from ..mapper.base import Mapper from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
@@ -34,6 +37,10 @@ GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()


class HierarchicalTree(Tree): class HierarchicalTree(Tree):
"""Define hierarchical tree.""" """Define hierarchical tree."""
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IRUSR | stat.S_IWUSR
modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR

_root_created = False _root_created = False
ROOT_LEVEL = 0 ROOT_LEVEL = 0


@@ -162,19 +169,31 @@ class HierarchicalTree(Tree):
report_folder = os.path.abspath(report_folder) report_folder = os.path.abspath(report_folder)


if not os.path.exists(out_folder): if not os.path.exists(out_folder):
os.makedirs(out_folder)
os.makedirs(out_folder, self.modes_usr)
if not os.path.exists(report_folder): if not os.path.exists(report_folder):
os.makedirs(report_folder)
os.makedirs(report_folder, self.modes_usr)


for file_name in code_fragments: for file_name in code_fragments:
code, report = code_fragments[file_name] code, report = code_fragments[file_name]
with open(os.path.join(os.path.abspath(out_folder),
f"{file_name}.py"), "w") as file:
file.write(code)

with open(os.path.join(report_folder,
f"report_of_{file_name}.txt"), "w") as rpt_f:
rpt_f.write(report)
try:
with os.fdopen(
os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
self.flags, self.modes), 'w') as file:
file.write(code)
except IOError as error:
log.error(str(error))
log.exception(error)
raise error

try:
with os.fdopen(
os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
self.flags, stat.S_IRUSR), "w") as rpt_f:
rpt_f.write(report)
except IOError as error:
log.error(str(error))
log.exception(error)
raise error


def _preprocess_node_args(self, node, module_key): def _preprocess_node_args(self, node, module_key):
""" """
@@ -625,7 +644,6 @@ class HierarchicalTree(Tree):
nd_inst = self.get_node(successor_name) nd_inst = self.get_node(successor_name)
# Generate variable name here, then # Generate variable name here, then
# to generate args. # to generate args.
# if nd_inst.data.node_type == NodeType.OPERATION.value:
if created: if created:
nd_inst.data.variable_name = self._module_vars[module_key][idx] nd_inst.data.variable_name = self._module_vars[module_key][idx]
else: else:


+ 6
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -16,6 +16,7 @@
import abc import abc
from collections import OrderedDict from collections import OrderedDict


from mindinsight.mindconverter.common.log import logger as log
from ..constant import SEPARATOR_IN_ONNX_OP from ..constant import SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper from ..mapper.base import Mapper


@@ -66,8 +67,11 @@ class BaseGraph(metaclass=abc.ABCMeta):
"""Control the create action of graph.""" """Control the create action of graph."""
model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL) model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL)
if not model_param: if not model_param:
raise ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
f"can not be None.")
error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
f"can not be None.")
log.error(str(error))
log.exception(error)
raise error


return super(BaseGraph, cls).__new__(cls) return super(BaseGraph, cls).__new__(cls)




+ 20
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py View File

@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Third party graph parser.""" """Third party graph parser."""
import os import os
from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser from .base import GraphParser




@@ -34,12 +35,24 @@ class PyTorchGraphParser(GraphParser):
import torch import torch


if not os.path.exists(model_path): if not os.path.exists(model_path):
raise FileNotFoundError("`model_path` must be assigned with "
"an existed file path.")

if torch.cuda.is_available():
model = torch.load(f=model_path)
else:
model = torch.load(f=model_path, map_location="cpu")
error = FileNotFoundError("`model_path` must be assigned with "
"an existed file path.")
log.error(str(error))
log.exception(error)
raise error

try:
if torch.cuda.is_available():
model = torch.load(f=model_path)
else:
model = torch.load(f=model_path, map_location="cpu")
except ModuleNotFoundError:
error_msg = \
"Cannot find model scripts in system path, " \
"set `--project_path` to the path of model scripts folder correctly."
error = ModuleNotFoundError(error_msg)
log.error(str(error))
log.exception(error)
raise error


return model return model

+ 59
- 11
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -17,6 +17,7 @@ import warnings
import re import re
from typing import Dict, NoReturn from typing import Dict, NoReturn


from mindinsight.mindconverter.common.log import logger as log
from .base import Graph from .base import Graph
from .input_node import InputNode from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode from .pytorch_graph_node import PyTorchGraphNode
@@ -89,12 +90,18 @@ class PyTorchGraph(Graph):


""" """
if not input_shape: if not input_shape:
raise ValueError("`input_shape` can not be None.")
error = ValueError("`input_shape` can not be None.")
log.error(str(error))
log.exception(error)
raise error


for item in input_shape: for item in input_shape:
if not isinstance(item, int): if not isinstance(item, int):
raise ValueError(f"Only support model with one input now, "
f"and each shape value in `input_shape` should be int.")
error = ValueError(f"Only support model with one input now, "
f"and each shape value in `input_shape` should be int.")
log.error(str(error))
log.exception(error)
raise error


def build(self, input_shape): def build(self, input_shape):
""" """
@@ -122,9 +129,11 @@ class PyTorchGraph(Graph):
Returns: Returns:
list, shape. list, shape.
""" """
pattern = re.compile(r"\d+:\d*")
if not pattern.findall(shape):
if "," not in shape:
return [] return []
for s in shape.split(","):
if not s:
return []
return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')] return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]


feed_forward_ipt_shape = (1, *input_shape) feed_forward_ipt_shape = (1, *input_shape)
@@ -133,10 +142,15 @@ class PyTorchGraph(Graph):
# Assign execution mode to eval. # Assign execution mode to eval.
self.model.eval() self.model.eval()


with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
try:
with OverloadTorchModuleTemporarily() as _:
# In pytorch higher version, trace function has a known.
graph = onnx_tracer(self.model, batched_sample,
OperatorExportTypes.ONNX)
except RuntimeError as error:
log.error(str(error))
log.exception(error)
raise error


nodes = list(graph.nodes()) nodes = list(graph.nodes())


@@ -190,6 +204,37 @@ class PyTorchGraph(Graph):
""" """
raise NotImplementedError() raise NotImplementedError()


def to_hierarchical_tree(self):
"""
Generate hierarchical tree based on graph.
"""
from ..hierarchical_tree import HierarchicalTree

tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(self.nodes_in_topological_order):
node_inst = self.get_node(node_name)
node_output = self._shape_dict.get(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = self._input_shape.get(node_name)

if not node_input:
error = ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")
log.error(str(error))
log.exception(error)
raise error

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

def build_connection(self, src, tgt) -> NoReturn: def build_connection(self, src, tgt) -> NoReturn:
""" """
Build connection between source node and target node. Build connection between source node and target node.
@@ -229,8 +274,11 @@ class PyTorchGraph(Graph):
""" """
Load graph metadata. Load graph metadata.
""" """
raise NotImplementedError("class `PyTorchGraph` has not implemented "
"`load_metadata()`.")
error = NotImplementedError("class `PyTorchGraph` has not implemented "
"`load_metadata()`.")
log.error(str(error))
log.exception(error)
raise error


@staticmethod @staticmethod
def load_graph(graph_path: str): def load_graph(graph_path: str):


+ 0
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py View File

@@ -116,8 +116,6 @@ class PyTorchGraphNode(GraphNode):
""" """
if not self._module_name_frozen: if not self._module_name_frozen:
module_name = self.tag module_name = self.tag
# if self._node_type == NodeType.CLASS.value:
# module_name = f"{module_name[0].upper()}{module_name[1:]}"
return module_name return module_name


return self._module_name return self._module_name


Loading…
Cancel
Save