You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

optimizer_config.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Validator for optimizer config."""
  16. import math
  17. from marshmallow import Schema, fields, ValidationError, validates, validate, validates_schema
  18. from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \
  19. HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \
  20. HyperParamKey, SystemDefinedTargets
  21. from mindinsight.optimizer.utils.utils import is_param_name_valid
  22. _BOUND_LEN = 2
  23. _NUMBER_ERR_MSG = "Value(s) should be integer or float."
  24. _TYPE_ERR_MSG = "Value type should be %r."
  25. _VALUE_ERR_MSG = "Value should be in %s. Current value is %s."
  26. def _generate_schema_err_msg(err_msg, *args):
  27. """Organize error messages."""
  28. if args:
  29. err_msg = err_msg % args
  30. return {"invalid": err_msg}
  31. def _generate_err_msg_for_nested_keys(err_msg, *args):
  32. """Organize error messages for system defined parameters."""
  33. err_dict = {}
  34. for name in args[::-1]:
  35. if not err_dict:
  36. err_dict.update({name: err_msg})
  37. else:
  38. err_dict = {name: err_dict}
  39. return err_dict
  40. def include_integer(low, high):
  41. """Check if the range [low, high) includes integer."""
  42. def _in_range(num, low, high):
  43. """check if num in [low, high)"""
  44. return low <= num < high
  45. if _in_range(math.ceil(low), low, high) or _in_range(math.floor(high), low, high):
  46. return True
  47. return False
  48. class TunerSchema(Schema):
  49. """Schema for tuner."""
  50. dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
  51. name = fields.Str(required=True,
  52. validate=validate.OneOf(TuneMethod.list_members()),
  53. error_messages=_generate_schema_err_msg("Value should be in %s." % TuneMethod.list_members()))
  54. args = fields.Dict(error_messages=dict_err_msg)
  55. @validates("args")
  56. def check_args(self, data):
  57. """Check args for tuner."""
  58. data_keys = list(data.keys())
  59. support_args = GPSupportArgs.list_members()
  60. if not set(data_keys).issubset(set(support_args)):
  61. raise ValidationError("Only support setting %s for tuner. "
  62. "Current key(s): %s." % (support_args, data_keys))
  63. method = data.get(GPSupportArgs.METHOD.value)
  64. if not isinstance(method, str):
  65. raise ValidationError("The 'method' type should be str.")
  66. if method not in AcquisitionFunctionEnum.list_members():
  67. raise ValidationError("Supported acquisition function must be one of %s. Current value is %r." %
  68. (AcquisitionFunctionEnum.list_members(), method))
  69. class ParameterSchema(Schema):
  70. """Schema for parameter."""
  71. number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG)
  72. list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list")
  73. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
  74. bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
  75. choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
  76. type = fields.Str(error_messages=str_err_msg)
  77. source = fields.Str(error_messages=str_err_msg)
  78. @validates("bounds")
  79. def check_bounds(self, bounds):
  80. """Check if bounds are valid."""
  81. if len(bounds) != _BOUND_LEN:
  82. raise ValidationError("Length of bounds should be %s." % _BOUND_LEN)
  83. if bounds[1] <= bounds[0]:
  84. raise ValidationError("The upper bound must be greater than lower bound. "
  85. "The range is [lower_bound, upper_bound).")
  86. @validates("choice")
  87. def check_choice(self, choice):
  88. """Check if choice is valid."""
  89. if not choice:
  90. raise ValidationError("It is empty, please fill in at least one value.")
  91. @validates("type")
  92. def check_type(self, type_in):
  93. """Check if type is valid."""
  94. if type_in not in HyperParamType.list_members():
  95. raise ValidationError("It should be in %s." % HyperParamType.list_members())
  96. @validates("source")
  97. def check_source(self, source):
  98. """Check if source is valid."""
  99. if source not in HyperParamSource.list_members():
  100. raise ValidationError(_VALUE_ERR_MSG % (HyperParamSource.list_members(), source))
  101. @validates_schema
  102. def check_combination(self, data, **kwargs):
  103. """check the combination of parameters."""
  104. bound_key = HyperParamKey.BOUND.value
  105. choice_key = HyperParamKey.CHOICE.value
  106. type_key = HyperParamKey.TYPE.value
  107. # check bound and type
  108. bounds = data.get(bound_key)
  109. param_type = data.get(type_key)
  110. if bounds is not None:
  111. if param_type is None:
  112. raise ValidationError("If %r is specified, the %r should be specified also." %
  113. (HyperParamKey.BOUND.value, HyperParamKey.TYPE.value))
  114. if param_type == HyperParamType.INT.value and not include_integer(bounds[0], bounds[1]):
  115. raise ValidationError("No integer in 'bounds', please modify it.")
  116. # check bound and choice
  117. if (bound_key in data and choice_key in data) or (bound_key not in data and choice_key not in data):
  118. raise ValidationError("Only one of [%r, %r] should be specified." %
  119. (bound_key, choice_key))
  120. class TargetSchema(Schema):
  121. """Schema for target."""
  122. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
  123. group = fields.Str(error_messages=str_err_msg)
  124. name = fields.Str(required=True, error_messages=str_err_msg)
  125. goal = fields.Str(error_messages=str_err_msg)
  126. @validates("group")
  127. def check_group(self, group):
  128. """Check if bounds are valid."""
  129. if group not in TargetGroup.list_members():
  130. raise ValidationError(_VALUE_ERR_MSG % (TargetGroup.list_members(), group))
  131. @validates("goal")
  132. def check_goal(self, goal):
  133. """Check if source is valid."""
  134. if goal not in TargetGoal.list_members():
  135. raise ValidationError(_VALUE_ERR_MSG % (TargetGoal.list_members(), goal))
  136. @validates_schema
  137. def check_combination(self, data, **kwargs):
  138. """check the combination of parameters."""
  139. if TargetKey.GROUP.value not in data:
  140. # if name is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
  141. return
  142. name = data.get(TargetKey.NAME.value)
  143. group = data.get(TargetKey.GROUP.value)
  144. if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members():
  145. raise ValidationError({
  146. TargetKey.GROUP.value: "This target is not system defined. Current group is: %s." % group})
  147. class OptimizerConfig(Schema):
  148. """Define the search model condition parameter schema."""
  149. dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
  150. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
  151. summary_base_dir = fields.Str(required=True, error_messages=str_err_msg)
  152. command = fields.Str(required=True, error_messages=str_err_msg)
  153. tuner = fields.Dict(required=True, error_messages=dict_err_msg)
  154. target = fields.Dict(required=True, error_messages=dict_err_msg)
  155. parameters = fields.Dict(required=True, error_messages=dict_err_msg)
  156. def _check_tunable_system_parameters(self, name, value):
  157. """Check tunable system parameters."""
  158. bound = value.get(HyperParamKey.BOUND.value)
  159. choice = value.get(HyperParamKey.CHOICE.value)
  160. param_type = value.get(HyperParamKey.TYPE.value)
  161. err_msg = "The value(s) should be positive number."
  162. if bound is not None and bound[0] <= 0:
  163. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.BOUND.value))
  164. if choice is not None and min(choice) <= 0:
  165. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
  166. if name == TunableSystemDefinedParams.LEARNING_RATE.value:
  167. if bound is not None and bound[1] > 1:
  168. err_msg = "The upper bound should be less than and equal to 1."
  169. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.BOUND.value))
  170. if choice is not None and max(choice) >= 1:
  171. err_msg = "The values should be float number less than to 1."
  172. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
  173. if param_type == HyperParamType.INT.value:
  174. err_msg = "The value(s) should be float number, please config it as %s." % HyperParamType.FLOAT.value
  175. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
  176. else:
  177. if choice is not None and list(filter(lambda x: not isinstance(x, int), choice)):
  178. # if the choice contains value(s) which is not integer
  179. err_msg = "The value(s) should be integer."
  180. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
  181. if bound is not None and param_type != HyperParamType.INT.value:
  182. # if bound is configured, need to config its type as int.
  183. err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
  184. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
  185. @validates("tuner")
  186. def check_tuner(self, data):
  187. """Check tuner."""
  188. err = TunerSchema().validate(data)
  189. if err:
  190. raise ValidationError(err)
  191. @validates("parameters")
  192. def check_parameters(self, parameters):
  193. """Check parameters."""
  194. for name, value in parameters.items():
  195. if not is_param_name_valid(name):
  196. raise ValidationError("Parameter name %r is not a valid name, only number(0-9), alphabet(a-z, A-Z) "
  197. "and underscore(_) characters are allowed in name." % name)
  198. err = ParameterSchema().validate(value)
  199. if err:
  200. raise ValidationError({name: err})
  201. source = value.get(HyperParamKey.SOURCE.value)
  202. if source in [None, HyperParamSource.SYSTEM_DEFINED.value] and \
  203. name in TunableSystemDefinedParams.list_members():
  204. self._check_tunable_system_parameters(name, value)
  205. if source is None:
  206. # if params is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
  207. continue
  208. if source == HyperParamSource.SYSTEM_DEFINED.value and \
  209. name not in TunableSystemDefinedParams.list_members():
  210. raise ValidationError({
  211. name: {"source": "This param is not system defined. Current source is: %s." % source}})
  212. @validates("target")
  213. def check_target(self, target):
  214. """Check target."""
  215. err = TargetSchema().validate(target)
  216. if err:
  217. raise ValidationError(err)