diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index 08881c76..a5833186 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -21,7 +21,7 @@ import mindinsight from mindinsight.mindconverter.converter import main from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, \ -ARGUMENT_NUM_LIMIT, ARGUMENT_LEN_LIMIT, FrameworkType + ARGUMENT_NUM_LIMIT, ARGUMENT_LEN_LIMIT, FrameworkType from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console @@ -283,11 +283,19 @@ class NodeAction(argparse.Action): ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) if len(values) > ARGUMENT_NUM_LIMIT: parser_in.error(f"The length of {option_string} {values} should be no more than {ARGUMENT_NUM_LIMIT}.") + deduplicated = set() + abnormal_nodes = [] for v in values: if len(v) > ARGUMENT_LENGTH_LIMIT: parser_in.error( f"The length of {option_string} {v} should be no more than {ARGUMENT_LENGTH_LIMIT}." ) + if v in deduplicated: + abnormal_nodes.append(v) + continue + deduplicated.add(v) + if abnormal_nodes: + parser_in.error(f"{', '.join(abnormal_nodes)} {'is' if len(abnormal_nodes) == 1 else 'are'} duplicated.") setattr(namespace, self.dest, values) diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 8438c404..1422b5b3 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -281,8 +281,7 @@ def main_graph_base_converter(file_config): check_params = ['input_nodes', 'output_nodes'] check_params_exist(check_params, file_config) - if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len( - set(file_config.get("input_nodes", []))): + if len(file_config['shape']) != len(file_config.get("input_nodes", [])): raise BadParamError("`--shape` and `--input_nodes` must have the same length, " "and no redundant node in `--input_nodes`.")