You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimizer_config.py 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Validator for optimizer config."""
  16. import math
  17. from marshmallow import Schema, fields, ValidationError, validates, validate, validates_schema
  18. from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \
  19. HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \
  20. HyperParamKey, SystemDefinedTargets
  21. _BOUND_LEN = 2
  22. _NUMBER_ERR_MSG = "Value(s) should be integer or float."
  23. _TYPE_ERR_MSG = "Value type should be %r."
  24. _VALUE_ERR_MSG = "Value should be in %s. Current value is %s."
  25. def _generate_schema_err_msg(err_msg, *args):
  26. """Organize error messages."""
  27. if args:
  28. err_msg = err_msg % args
  29. return {"invalid": err_msg}
  30. def include_integer(low, high):
  31. """Check if the range [low, high) includes integer."""
  32. def _in_range(num, low, high):
  33. """check if num in [low, high)"""
  34. return low <= num < high
  35. if _in_range(math.ceil(low), low, high) or _in_range(math.floor(high), low, high):
  36. return True
  37. return False
  38. class TunerSchema(Schema):
  39. """Schema for tuner."""
  40. dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
  41. name = fields.Str(required=True,
  42. validate=validate.OneOf(TuneMethod.list_members()),
  43. error_messages=_generate_schema_err_msg("Value should be in %s." % TuneMethod.list_members()))
  44. args = fields.Dict(error_messages=dict_err_msg)
  45. @validates("args")
  46. def check_args(self, data):
  47. """Check args for tuner."""
  48. data_keys = list(data.keys())
  49. support_args = GPSupportArgs.list_members()
  50. if not set(data_keys).issubset(set(support_args)):
  51. raise ValidationError("Only support setting %s for tuner. "
  52. "Current key(s): %s." % (support_args, data_keys))
  53. method = data.get(GPSupportArgs.METHOD.value)
  54. if not isinstance(method, str):
  55. raise ValidationError("The 'method' type should be str.")
  56. if method not in AcquisitionFunctionEnum.list_members():
  57. raise ValidationError("Supported acquisition function must be one of %s. Current value is %r." %
  58. (AcquisitionFunctionEnum.list_members(), method))
  59. class ParameterSchema(Schema):
  60. """Schema for parameter."""
  61. number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG)
  62. list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list")
  63. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
  64. bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
  65. choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
  66. type = fields.Str(error_messages=list_err_msg)
  67. source = fields.Str(error_messages=str_err_msg)
  68. @validates("bounds")
  69. def check_bounds(self, bounds):
  70. """Check if bounds are valid."""
  71. if len(bounds) != _BOUND_LEN:
  72. raise ValidationError("Length of bounds should be %s." % _BOUND_LEN)
  73. if bounds[1] <= bounds[0]:
  74. raise ValidationError("The upper bound must be greater than lower bound. "
  75. "The range is [lower_bound, upper_bound).")
  76. @validates("type")
  77. def check_type(self, type_in):
  78. """Check if type is valid."""
  79. if type_in not in HyperParamType.list_members():
  80. raise ValidationError("The type should be in %s." % HyperParamType.list_members())
  81. @validates("source")
  82. def check_source(self, source):
  83. """Check if source is valid."""
  84. if source not in HyperParamSource.list_members():
  85. raise ValidationError(_VALUE_ERR_MSG % (HyperParamSource.list_members(), source))
  86. @validates_schema
  87. def check_combination(self, data, **kwargs):
  88. """check the combination of parameters."""
  89. bound_key = HyperParamKey.BOUND.value
  90. choice_key = HyperParamKey.CHOICE.value
  91. type_key = HyperParamKey.TYPE.value
  92. # check bound and type
  93. bounds = data.get(bound_key)
  94. param_type = data.get(type_key)
  95. if bounds is not None:
  96. if param_type is None:
  97. raise ValidationError("If %r is specified, the %r should be specified also." %
  98. (HyperParamKey.BOUND.value, HyperParamKey.TYPE.value))
  99. if param_type == HyperParamType.INT.value and not include_integer(bounds[0], bounds[1]):
  100. raise ValidationError("No integer in 'bounds', please modify it.")
  101. # check bound and choice
  102. if (bound_key in data and choice_key in data) or (bound_key not in data and choice_key not in data):
  103. raise ValidationError("Only one of [%r, %r] should be specified." %
  104. (bound_key, choice_key))
  105. class TargetSchema(Schema):
  106. """Schema for target."""
  107. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
  108. group = fields.Str(error_messages=str_err_msg)
  109. name = fields.Str(required=True, error_messages=str_err_msg)
  110. goal = fields.Str(error_messages=str_err_msg)
  111. @validates("group")
  112. def check_group(self, group):
  113. """Check if bounds are valid."""
  114. if group not in TargetGroup.list_members():
  115. raise ValidationError(_VALUE_ERR_MSG % (TargetGroup.list_members(), group))
  116. @validates("goal")
  117. def check_goal(self, goal):
  118. """Check if source is valid."""
  119. if goal not in TargetGoal.list_members():
  120. raise ValidationError(_VALUE_ERR_MSG % (TargetGoal.list_members(), goal))
  121. @validates_schema
  122. def check_combination(self, data, **kwargs):
  123. """check the combination of parameters."""
  124. if TargetKey.GROUP.value not in data:
  125. # if name is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
  126. return
  127. name = data.get(TargetKey.NAME.value)
  128. group = data.get(TargetKey.GROUP.value)
  129. if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members():
  130. raise ValidationError({
  131. TargetKey.GROUP.value: "This target is not system defined. Current group is: %s." % group})
  132. class OptimizerConfig(Schema):
  133. """Define the search model condition parameter schema."""
  134. dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
  135. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
  136. summary_base_dir = fields.Str(required=True, error_messages=str_err_msg)
  137. command = fields.Str(required=True, error_messages=str_err_msg)
  138. tuner = fields.Dict(required=True, error_messages=dict_err_msg)
  139. target = fields.Dict(required=True, error_messages=dict_err_msg)
  140. parameters = fields.Dict(required=True, error_messages=dict_err_msg)
  141. @validates("tuner")
  142. def check_tuner(self, data):
  143. """Check tuner."""
  144. err = TunerSchema().validate(data)
  145. if err:
  146. raise ValidationError(err)
  147. @validates("parameters")
  148. def check_parameters(self, parameters):
  149. """Check parameters."""
  150. for name, value in parameters.items():
  151. err = ParameterSchema().validate(value)
  152. if err:
  153. raise ValidationError({name: err})
  154. if HyperParamKey.SOURCE.value not in value:
  155. # if params is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
  156. continue
  157. source = value.get(HyperParamKey.SOURCE.value)
  158. if source == HyperParamSource.SYSTEM_DEFINED.value and \
  159. name not in TunableSystemDefinedParams.list_members():
  160. raise ValidationError({
  161. name: {"source": "This param is not system defined. Current source is: %s." % source}})
  162. @validates("target")
  163. def check_target(self, target):
  164. """Check target."""
  165. err = TargetSchema().validate(target)
  166. if err:
  167. raise ValidationError(err)