diff --git a/mindinsight/optimizer/common/validator/optimizer_config.py b/mindinsight/optimizer/common/validator/optimizer_config.py index 2683e1a5..632ff767 100644 --- a/mindinsight/optimizer/common/validator/optimizer_config.py +++ b/mindinsight/optimizer/common/validator/optimizer_config.py @@ -20,6 +20,7 @@ from marshmallow import Schema, fields, ValidationError, validates, validate, va from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \ HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \ HyperParamKey, SystemDefinedTargets +from mindinsight.optimizer.utils.utils import is_param_name_valid _BOUND_LEN = 2 _NUMBER_ERR_MSG = "Value(s) should be integer or float." @@ -230,6 +231,10 @@ class OptimizerConfig(Schema): def check_parameters(self, parameters): """Check parameters.""" for name, value in parameters.items(): + if not is_param_name_valid(name): + raise ValidationError("Parameter name %r is not a valid name, only number(0-9), alphabet(a-z, A-Z) " + "and underscore(_) characters are allowed in name." % name) + err = ParameterSchema().validate(value) if err: raise ValidationError({name: err}) diff --git a/mindinsight/optimizer/utils/utils.py b/mindinsight/optimizer/utils/utils.py index a9cacbb2..ddcfe3ea 100644 --- a/mindinsight/optimizer/utils/utils.py +++ b/mindinsight/optimizer/utils/utils.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================ """Utils for optimizer.""" +import string + import numpy as np _DEFAULT_HISTOGRAM_BINS = 5 @@ -83,3 +85,32 @@ def get_nested_message(info: dict, out_err_msg=""): else: out_err_msg = key return get_nested_message(info[key], out_err_msg) + + +def is_number(uchar): + """If it is a number, return True.""" + if uchar in string.digits: + return True + return False + + +def is_alphabet(uchar): + """If it is a alphabet, return True.""" + if uchar in string.ascii_letters: + return True + return False + + +def is_allowed_symbols(uchar): + """If it is a allowed symbol, return True.""" + if uchar in ['_']: + return True + return False + + +def is_param_name_valid(param_name: str): + """If parameter name only contains number or alphabet.""" + for uchar in param_name: + if not is_number(uchar) and not is_alphabet(uchar) and not is_allowed_symbols(uchar): + return False + return True