Browse Source

!840 check if the parameter name is valid

Merge pull request !840 from luopengting/fix_optimizer
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b4b18cc96d
2 changed files with 36 additions and 0 deletions
  1. +5
    -0
      mindinsight/optimizer/common/validator/optimizer_config.py
  2. +31
    -0
      mindinsight/optimizer/utils/utils.py

+ 5
- 0
mindinsight/optimizer/common/validator/optimizer_config.py View File

@@ -20,6 +20,7 @@ from marshmallow import Schema, fields, ValidationError, validates, validate, va
from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \
HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \
HyperParamKey, SystemDefinedTargets
from mindinsight.optimizer.utils.utils import is_param_name_valid

_BOUND_LEN = 2
_NUMBER_ERR_MSG = "Value(s) should be integer or float."
@@ -230,6 +231,10 @@ class OptimizerConfig(Schema):
def check_parameters(self, parameters):
"""Check parameters."""
for name, value in parameters.items():
if not is_param_name_valid(name):
raise ValidationError("Parameter name %r is not a valid name, only number(0-9), alphabet(a-z, A-Z) "
"and underscore(_) characters are allowed in name." % name)

err = ParameterSchema().validate(value)
if err:
raise ValidationError({name: err})


+ 31
- 0
mindinsight/optimizer/utils/utils.py View File

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Utils for optimizer."""
import string

import numpy as np

_DEFAULT_HISTOGRAM_BINS = 5
@@ -83,3 +85,32 @@ def get_nested_message(info: dict, out_err_msg=""):
else:
out_err_msg = key
return get_nested_message(info[key], out_err_msg)


def is_number(uchar):
"""If it is a number, return True."""
if uchar in string.digits:
return True
return False


def is_alphabet(uchar):
"""If it is a alphabet, return True."""
if uchar in string.ascii_letters:
return True
return False


def is_allowed_symbols(uchar):
"""If it is a allowed symbol, return True."""
if uchar in ['_']:
return True
return False


def is_param_name_valid(param_name: str):
"""If parameter name only contains number or alphabet."""
for uchar in param_name:
if not is_number(uchar) and not is_alphabet(uchar) and not is_allowed_symbols(uchar):
return False
return True

Loading…
Cancel
Save