From d3d739a6679c547abf6f1c97f967c631afbe6e90 Mon Sep 17 00:00:00 2001 From: moran Date: Tue, 27 Oct 2020 12:07:31 +0800 Subject: [PATCH] Add multi-args check & Add args number check --- mindinsight/mindconverter/cli.py | 44 +++++++++++++++++-- .../graph_based_converter/constant.py | 2 +- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index 2cfa877b..c9b6bfa2 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -19,12 +19,22 @@ import argparse import mindinsight from mindinsight.mindconverter.converter import main -from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER +from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter from mindinsight.mindconverter.common.log import logger as log +class ArgsCheck: + """Args check.""" + + @staticmethod + def check_repeated(namespace, dest, default, option_string, parser_in): + """Check repeated.""" + if getattr(namespace, dest, default) is not default: + parser_in.error(f'Parameter `{option_string}` is set repeatedly.') + + class FileDirAction(argparse.Action): """File directory action class definition.""" @@ -64,6 +74,9 @@ class FileDirAction(argparse.Action): values (object): Argument values with type depending on argument definition. option_string (str): Optional string for specific argument name. Default: None. """ + + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + outfile_dir = self.check_path(parser_in, values, option_string) if os.path.isfile(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} is a file') @@ -84,6 +97,9 @@ class OutputDirAction(argparse.Action): values (object): Argument values with type depending on argument definition. option_string (str): Optional string for specific argument name. Default: None. """ + + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + output = values if len(output) > ARGUMENT_LENGTH_LIMIT: @@ -119,6 +135,9 @@ class ProjectPathAction(argparse.Action): values (object): Argument values with type depending on argument definition. option_string (str): Optional string for specific argument name. Default: None. """ + + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + outfile_dir = FileDirAction.check_path(parser_in, values, option_string) if not os.path.exists(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} not exists') @@ -141,6 +160,9 @@ class InFileAction(argparse.Action): values (object): Argument values with type depending on argument definition. option_string (str): Optional string for specific argument name. Default: None. """ + + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + outfile_dir = FileDirAction.check_path(parser_in, values, option_string) if not os.path.exists(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} not exists') @@ -164,6 +186,8 @@ class ModelFileAction(argparse.Action): values (object): Argument values with type depending on argument definition. option_string (str): Optional string for specific argument name. Default: None. """ + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + outfile_dir = FileDirAction.check_path(parser_in, values, option_string) if not os.path.exists(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} not exists') @@ -187,6 +211,9 @@ class LogFileAction(argparse.Action): values (object): Argument values with type depending on argument definition. option_string (str): Optional string for specific argument name. Default: None. """ + + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + outfile_dir = FileDirAction.check_path(parser_in, values, option_string) if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} is not a directory') @@ -206,11 +233,14 @@ class ShapeAction(argparse.Action): values (object): Argument values with type depending on argument definition. option_string (str): Optional string for specific argument name. Default: None. """ + + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + in_shape = None shape_str = values shape_list = shape_str.split(':') - if not len(shape_list) == EXPECTED_SHAPE_NUMBER: + if not len(shape_list) == EXPECTED_NUMBER: parser_in.error(f"Only support one shape now, but get {len(shape_list)}.") try: @@ -235,18 +265,26 @@ class NodeAction(argparse.Action): option_string (str): Optional string for specific argument name. Default: None. """ + + ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) + node_str = values if len(node_str) > ARGUMENT_LENGTH_LIMIT: parser_in.error( f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}." ) + node_list = node_str.split(',') + if not len(node_list) == EXPECTED_NUMBER: + parser_in.error(f"Only support one {option_string} now, but get {len(node_list)}.") + setattr(namespace, self.dest, node_str) parser = argparse.ArgumentParser( prog='mindconverter', - description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) + description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__), + allow_abbrev=False) parser.add_argument( '--version', diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index d9a160cb..432f6542 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -42,7 +42,7 @@ BINARY_HEADER_PYTORCH_BITS = 32 ARGUMENT_LENGTH_LIMIT = 512 -EXPECTED_SHAPE_NUMBER = 1 +EXPECTED_NUMBER = 1 @unique