Browse Source

!641 Optimize Codes to solve the problem that functions are too long.

Merge pull request !641 from moran/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
398cf46ea7
2 changed files with 68 additions and 66 deletions
  1. +49
    -48
      mindinsight/mindconverter/cli.py
  2. +19
    -18
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py

+ 49
- 48
mindinsight/mindconverter/cli.py View File

@@ -28,12 +28,12 @@ class FileDirAction(argparse.Action):
"""File directory action class definition.""" """File directory action class definition."""


@staticmethod @staticmethod
def check_path(parser, values, option_string=None):
def check_path(parser_in, values, option_string=None):
""" """
Check argument for file path. Check argument for file path.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
@@ -45,22 +45,22 @@ class FileDirAction(argparse.Action):
outfile = os.path.realpath(os.path.join(os.getcwd(), outfile)) outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))


if os.path.exists(outfile) and not os.access(outfile, os.R_OK): if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
parser.error(f'{option_string} {outfile} not accessible')
parser_in.error(f'{option_string} {outfile} not accessible')
return outfile return outfile


def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser_in, namespace, values, option_string=None):
""" """
Inherited __call__ method from argparse.Action. Inherited __call__ method from argparse.Action.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments. namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
outfile_dir = self.check_path(parser, values, option_string)
outfile_dir = self.check_path(parser_in, values, option_string)
if os.path.isfile(outfile_dir): if os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is a file')
parser_in.error(f'{option_string} {outfile_dir} is a file')


setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)


@@ -68,12 +68,12 @@ class FileDirAction(argparse.Action):
class OutputDirAction(argparse.Action): class OutputDirAction(argparse.Action):
"""File directory action class definition.""" """File directory action class definition."""


def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser_in, namespace, values, option_string=None):
""" """
Inherited __call__ method from argparse.Action. Inherited __call__ method from argparse.Action.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments. namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
@@ -87,10 +87,10 @@ class OutputDirAction(argparse.Action):


if os.path.exists(output): if os.path.exists(output):
if not os.access(output, os.R_OK): if not os.access(output, os.R_OK):
parser.error(f'{option_string} {output} not accessible')
parser_in.error(f'{option_string} {output} not accessible')


if os.path.isfile(output): if os.path.isfile(output):
parser.error(f'{option_string} {output} is a file')
parser_in.error(f'{option_string} {output} is a file')


setattr(namespace, self.dest, output) setattr(namespace, self.dest, output)


@@ -98,21 +98,21 @@ class OutputDirAction(argparse.Action):
class ProjectPathAction(argparse.Action): class ProjectPathAction(argparse.Action):
"""Project directory action class definition.""" """Project directory action class definition."""


def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser_in, namespace, values, option_string=None):
""" """
Inherited __call__ method from argparse.Action. Inherited __call__ method from argparse.Action.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments. namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
outfile_dir = FileDirAction.check_path(parser, values, option_string)
outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if not os.path.exists(outfile_dir): if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
parser_in.error(f'{option_string} {outfile_dir} not exists')
if not os.path.isdir(outfile_dir): if not os.path.isdir(outfile_dir):
parser.error(f'{option_string} [{outfile_dir}] should be a directory.')
parser_in.error(f'{option_string} [{outfile_dir}] should be a directory.')


setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)


@@ -120,22 +120,22 @@ class ProjectPathAction(argparse.Action):
class InFileAction(argparse.Action): class InFileAction(argparse.Action):
"""Input File action class definition.""" """Input File action class definition."""


def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser_in, namespace, values, option_string=None):
""" """
Inherited __call__ method from argparse.Action. Inherited __call__ method from argparse.Action.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments. namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
outfile_dir = FileDirAction.check_path(parser, values, option_string)
outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if not os.path.exists(outfile_dir): if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
parser_in.error(f'{option_string} {outfile_dir} not exists')


if not os.path.isfile(outfile_dir): if not os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a file')
parser_in.error(f'{option_string} {outfile_dir} is not a file')


setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)


@@ -143,25 +143,25 @@ class InFileAction(argparse.Action):
class ModelFileAction(argparse.Action): class ModelFileAction(argparse.Action):
"""Model File action class definition.""" """Model File action class definition."""


def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser_in, namespace, values, option_string=None):
""" """
Inherited __call__ method from argparse.Action. Inherited __call__ method from argparse.Action.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments. namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
outfile_dir = FileDirAction.check_path(parser, values, option_string)
outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if not os.path.exists(outfile_dir): if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
parser_in.error(f'{option_string} {outfile_dir} not exists')


if not os.path.isfile(outfile_dir): if not os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a file')
parser_in.error(f'{option_string} {outfile_dir} is not a file')


if not outfile_dir.endswith('.pth'): if not outfile_dir.endswith('.pth'):
parser.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.")
parser_in.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.")


setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)


@@ -169,31 +169,31 @@ class ModelFileAction(argparse.Action):
class LogFileAction(argparse.Action): class LogFileAction(argparse.Action):
"""Log file action class definition.""" """Log file action class definition."""


def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser_in, namespace, values, option_string=None):
""" """
Inherited __call__ method from FileDirAction. Inherited __call__ method from FileDirAction.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments. namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
outfile_dir = FileDirAction.check_path(parser, values, option_string)
outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir): if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a directory')
parser_in.error(f'{option_string} {outfile_dir} is not a directory')
setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)




class ShapeAction(argparse.Action): class ShapeAction(argparse.Action):
"""Shape action class definition.""" """Shape action class definition."""


def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser_in, namespace, values, option_string=None):
""" """
Inherited __call__ method from FileDirAction. Inherited __call__ method from FileDirAction.


Args: Args:
parser (ArgumentParser): Passed-in argument parser.
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments. namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
@@ -203,27 +203,21 @@ class ShapeAction(argparse.Action):
try: try:
in_shape = [int(num_shape) for num_shape in shape_str.split(',')] in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
except ValueError: except ValueError:
parser.error(
parser_in.error(
f"{option_string} {shape_str} should be a list of integer split by ',', check it please.") f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
setattr(namespace, self.dest, in_shape) setattr(namespace, self.dest, in_shape)




def cli_entry():
"""Entry point for mindconverter CLI."""

permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

parser = argparse.ArgumentParser(
parser = argparse.ArgumentParser(
prog='mindconverter', prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))


parser.add_argument(
parser.add_argument(
'--version', '--version',
action='version', action='version',
version='%(prog)s ({})'.format(mindinsight.__version__)) version='%(prog)s ({})'.format(mindinsight.__version__))


parser.add_argument(
parser.add_argument(
'--in_file', '--in_file',
type=str, type=str,
action=InFileAction, action=InFileAction,
@@ -234,7 +228,7 @@ def cli_entry():
do script conversation. do script conversation.
""") """)


parser.add_argument(
parser.add_argument(
'--model_file', '--model_file',
type=str, type=str,
action=ModelFileAction, action=ModelFileAction,
@@ -246,7 +240,7 @@ def cli_entry():
use AST schema as default. use AST schema as default.
""") """)


parser.add_argument(
parser.add_argument(
'--shape', '--shape',
type=str, type=str,
action=ShapeAction, action=ShapeAction,
@@ -259,7 +253,7 @@ def cli_entry():
Usage: --shape 3,244,244 Usage: --shape 3,244,244
""") """)


parser.add_argument(
parser.add_argument(
'--output', '--output',
type=str, type=str,
action=OutputDirAction, action=OutputDirAction,
@@ -270,7 +264,7 @@ def cli_entry():
in the current working directory. in the current working directory.
""") """)


parser.add_argument(
parser.add_argument(
'--report', '--report',
type=str, type=str,
action=LogFileAction, action=LogFileAction,
@@ -280,7 +274,7 @@ def cli_entry():
converted script directory. converted script directory.
""") """)


parser.add_argument(
parser.add_argument(
'--project_path', '--project_path',
type=str, type=str,
action=ProjectPathAction, action=ProjectPathAction,
@@ -293,6 +287,13 @@ def cli_entry():
Usage: --project_path ~/script_file/ Usage: --project_path ~/script_file/
""") """)



def cli_entry():
"""Entry point for mindconverter CLI."""

permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)

argv = sys.argv[1:] argv = sys.argv[1:]
if not argv: if not argv:
argv = ['-h'] argv = ['-h']


+ 19
- 18
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -103,6 +103,24 @@ class PyTorchGraph(Graph):
log.exception(error) log.exception(error)
raise error raise error


@staticmethod
def _extract_shape(shape):
"""
Extract shape from string-type shape.

Args:
shape (str): Shape value in string-type.

Returns:
list, shape.
"""
if "," not in shape:
return []
for s in shape.split(","):
if not s:
return []
return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]

def build(self, input_shape): def build(self, input_shape):
""" """
Build graph tree. Build graph tree.
@@ -119,23 +137,6 @@ class PyTorchGraph(Graph):


self._check_input_shape(input_shape) self._check_input_shape(input_shape)


def _extract_shape(shape):
"""
Extract shape from string-type shape.

Args:
shape (str): Shape value in string-type.

Returns:
list, shape.
"""
if "," not in shape:
return []
for s in shape.split(","):
if not s:
return []
return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]

feed_forward_ipt_shape = (1, *input_shape) feed_forward_ipt_shape = (1, *input_shape)
batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape)) batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape))


@@ -158,7 +159,7 @@ class PyTorchGraph(Graph):
node_name = normalize_scope_name(node) node_name = normalize_scope_name(node)
output_shape_str_list = re.findall(r'[^()!]+', str(node)) output_shape_str_list = re.findall(r'[^()!]+', str(node))
output_shape_str = output_shape_str_list[1] output_shape_str = output_shape_str_list[1]
output_shape = _extract_shape(output_shape_str)
output_shape = self._extract_shape(output_shape_str)
weight_scope = '.'.join( weight_scope = '.'.join(
re.findall(r'\[([\w\d.]+)]', node.scopeName()) re.findall(r'\[([\w\d.]+)]', node.scopeName())
) )


Loading…
Cancel
Save