Browse Source

!840 check if the parameter name is valid

Merge pull request !840 from luopengting/fix_optimizer
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b4b18cc96d
2 changed files with 36 additions and 0 deletions
  1. +5
    -0
      mindinsight/optimizer/common/validator/optimizer_config.py
  2. +31
    -0
      mindinsight/optimizer/utils/utils.py

+ 5
- 0
mindinsight/optimizer/common/validator/optimizer_config.py View File

@@ -20,6 +20,7 @@ from marshmallow import Schema, fields, ValidationError, validates, validate, va
from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \ from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \
HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \ HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \
HyperParamKey, SystemDefinedTargets HyperParamKey, SystemDefinedTargets
from mindinsight.optimizer.utils.utils import is_param_name_valid


_BOUND_LEN = 2 _BOUND_LEN = 2
_NUMBER_ERR_MSG = "Value(s) should be integer or float." _NUMBER_ERR_MSG = "Value(s) should be integer or float."
@@ -230,6 +231,10 @@ class OptimizerConfig(Schema):
def check_parameters(self, parameters): def check_parameters(self, parameters):
"""Check parameters.""" """Check parameters."""
for name, value in parameters.items(): for name, value in parameters.items():
if not is_param_name_valid(name):
raise ValidationError("Parameter name %r is not a valid name, only number(0-9), alphabet(a-z, A-Z) "
"and underscore(_) characters are allowed in name." % name)

err = ParameterSchema().validate(value) err = ParameterSchema().validate(value)
if err: if err:
raise ValidationError({name: err}) raise ValidationError({name: err})


+ 31
- 0
mindinsight/optimizer/utils/utils.py View File

@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Utils for optimizer.""" """Utils for optimizer."""
import string

import numpy as np import numpy as np


_DEFAULT_HISTOGRAM_BINS = 5 _DEFAULT_HISTOGRAM_BINS = 5
@@ -83,3 +85,32 @@ def get_nested_message(info: dict, out_err_msg=""):
else: else:
out_err_msg = key out_err_msg = key
return get_nested_message(info[key], out_err_msg) return get_nested_message(info[key], out_err_msg)


def is_number(uchar):
"""If it is a number, return True."""
if uchar in string.digits:
return True
return False


def is_alphabet(uchar):
"""If it is a alphabet, return True."""
if uchar in string.ascii_letters:
return True
return False


def is_allowed_symbols(uchar):
"""If it is a allowed symbol, return True."""
if uchar in ['_']:
return True
return False


def is_param_name_valid(param_name: str):
"""If parameter name only contains number or alphabet."""
for uchar in param_name:
if not is_number(uchar) and not is_alphabet(uchar) and not is_allowed_symbols(uchar):
return False
return True

Loading…
Cancel
Save