Browse Source

!1222 Add args validation of mindconverter cli

From: @liuchongming74
Reviewed-by: @yelihua,@ouwenchang
Signed-off-by: @ouwenchang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
7a117aabb4
2 changed files with 10 additions and 3 deletions
  1. +9
    -1
      mindinsight/mindconverter/cli.py
  2. +1
    -2
      mindinsight/mindconverter/graph_based_converter/framework.py

+ 9
- 1
mindinsight/mindconverter/cli.py View File

@@ -21,7 +21,7 @@ import mindinsight
from mindinsight.mindconverter.converter import main from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, \ from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, \
ARGUMENT_NUM_LIMIT, ARGUMENT_LEN_LIMIT, FrameworkType
ARGUMENT_NUM_LIMIT, ARGUMENT_LEN_LIMIT, FrameworkType
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter


from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
@@ -283,11 +283,19 @@ class NodeAction(argparse.Action):
ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in) ArgsCheck.check_repeated(namespace, self.dest, self.default, option_string, parser_in)
if len(values) > ARGUMENT_NUM_LIMIT: if len(values) > ARGUMENT_NUM_LIMIT:
parser_in.error(f"The length of {option_string} {values} should be no more than {ARGUMENT_NUM_LIMIT}.") parser_in.error(f"The length of {option_string} {values} should be no more than {ARGUMENT_NUM_LIMIT}.")
deduplicated = set()
abnormal_nodes = []
for v in values: for v in values:
if len(v) > ARGUMENT_LENGTH_LIMIT: if len(v) > ARGUMENT_LENGTH_LIMIT:
parser_in.error( parser_in.error(
f"The length of {option_string} {v} should be no more than {ARGUMENT_LENGTH_LIMIT}." f"The length of {option_string} {v} should be no more than {ARGUMENT_LENGTH_LIMIT}."
) )
if v in deduplicated:
abnormal_nodes.append(v)
continue
deduplicated.add(v)
if abnormal_nodes:
parser_in.error(f"{', '.join(abnormal_nodes)} {'is' if len(abnormal_nodes) == 1 else 'are'} duplicated.")
setattr(namespace, self.dest, values) setattr(namespace, self.dest, values)






+ 1
- 2
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -281,8 +281,7 @@ def main_graph_base_converter(file_config):
check_params = ['input_nodes', 'output_nodes'] check_params = ['input_nodes', 'output_nodes']
check_params_exist(check_params, file_config) check_params_exist(check_params, file_config)


if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len(
set(file_config.get("input_nodes", []))):
if len(file_config['shape']) != len(file_config.get("input_nodes", [])):
raise BadParamError("`--shape` and `--input_nodes` must have the same length, " raise BadParamError("`--shape` and `--input_nodes` must have the same length, "
"and no redundant node in `--input_nodes`.") "and no redundant node in `--input_nodes`.")




Loading…
Cancel
Save