Browse Source

!815 Fix bug to add multi-args check in MindConverter.

Merge pull request !815 from moran/network_validation
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c631869061
2 changed files with 42 additions and 4 deletions
  1. +41
    -3
      mindinsight/mindconverter/cli.py
  2. +1
    -1
      mindinsight/mindconverter/graph_based_converter/constant.py

+ 41
- 3
mindinsight/mindconverter/cli.py View File

@@ -19,12 +19,22 @@ import argparse


import mindinsight import mindinsight
from mindinsight.mindconverter.converter import main from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER
from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log




class ArgsCheck:
"""Args check."""

@staticmethod
def check_repeated(namespace, dest, default, option_string, parser_in):
"""Check repeated."""
if getattr(namespace, dest, default) is not default:
parser_in.error(f'Parameter `{option_string}` is set repeatedly.')


class FileDirAction(argparse.Action): class FileDirAction(argparse.Action):
"""File directory action class definition.""" """File directory action class definition."""


@@ -64,6 +74,9 @@ class FileDirAction(argparse.Action):
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """

ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

outfile_dir = self.check_path(parser_in, values, option_string) outfile_dir = self.check_path(parser_in, values, option_string)
if os.path.isfile(outfile_dir): if os.path.isfile(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} is a file') parser_in.error(f'{option_string} {outfile_dir} is a file')
@@ -84,6 +97,9 @@ class OutputDirAction(argparse.Action):
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """

ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

output = values output = values


if len(output) > ARGUMENT_LENGTH_LIMIT: if len(output) > ARGUMENT_LENGTH_LIMIT:
@@ -119,6 +135,9 @@ class ProjectPathAction(argparse.Action):
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """

ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

outfile_dir = FileDirAction.check_path(parser_in, values, option_string) outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if not os.path.exists(outfile_dir): if not os.path.exists(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} not exists') parser_in.error(f'{option_string} {outfile_dir} not exists')
@@ -141,6 +160,9 @@ class InFileAction(argparse.Action):
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """

ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

outfile_dir = FileDirAction.check_path(parser_in, values, option_string) outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if not os.path.exists(outfile_dir): if not os.path.exists(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} not exists') parser_in.error(f'{option_string} {outfile_dir} not exists')
@@ -164,6 +186,8 @@ class ModelFileAction(argparse.Action):
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

outfile_dir = FileDirAction.check_path(parser_in, values, option_string) outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if not os.path.exists(outfile_dir): if not os.path.exists(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} not exists') parser_in.error(f'{option_string} {outfile_dir} not exists')
@@ -187,6 +211,9 @@ class LogFileAction(argparse.Action):
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """

ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

outfile_dir = FileDirAction.check_path(parser_in, values, option_string) outfile_dir = FileDirAction.check_path(parser_in, values, option_string)
if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir): if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} is not a directory') parser_in.error(f'{option_string} {outfile_dir} is not a directory')
@@ -206,11 +233,14 @@ class ShapeAction(argparse.Action):
values (object): Argument values with type depending on argument definition. values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """

ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

in_shape = None in_shape = None
shape_str = values shape_str = values


shape_list = shape_str.split(':') shape_list = shape_str.split(':')
if not len(shape_list) == EXPECTED_SHAPE_NUMBER:
if not len(shape_list) == EXPECTED_NUMBER:
parser_in.error(f"Only support one shape now, but get {len(shape_list)}.") parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")


try: try:
@@ -235,18 +265,26 @@ class NodeAction(argparse.Action):
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.


""" """

ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)

node_str = values node_str = values
if len(node_str) > ARGUMENT_LENGTH_LIMIT: if len(node_str) > ARGUMENT_LENGTH_LIMIT:
parser_in.error( parser_in.error(
f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}." f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
) )


node_list = node_str.split(',')
if not len(node_list) == EXPECTED_NUMBER:
parser_in.error(f"Only support one {option_string} now, but get {len(node_list)}.")

setattr(namespace, self.dest, node_str) setattr(namespace, self.dest, node_str)




parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='mindconverter', prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__),
allow_abbrev=False)


parser.add_argument( parser.add_argument(
'--version', '--version',


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -42,7 +42,7 @@ BINARY_HEADER_PYTORCH_BITS = 32


ARGUMENT_LENGTH_LIMIT = 512 ARGUMENT_LENGTH_LIMIT = 512


EXPECTED_SHAPE_NUMBER = 1
EXPECTED_NUMBER = 1




@unique @unique


Loading…
Cancel
Save