You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

optimizer_config.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Validator for optimizer config."""
  16. import math
  17. from marshmallow import Schema, fields, ValidationError, validates, validate, validates_schema
  18. from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \
  19. HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \
  20. HyperParamKey, SystemDefinedTargets
  21. from mindinsight.optimizer.utils.utils import is_param_name_valid
  22. _BOUND_LEN = 2
  23. _NUMBER_ERR_MSG = "Value(s) should be integer or float."
  24. _TYPE_ERR_MSG = "Value should be a %s."
  25. _VALUE_ERR_MSG = "Value should be in %s. Current value is %r."
  26. def _generate_schema_err_msg(err_msg, *args):
  27. """Organize error messages."""
  28. if args:
  29. err_msg = err_msg % args
  30. return {"invalid": err_msg}
  31. def _generate_err_msg_for_nested_keys(err_msg, *args):
  32. """Organize error messages for system defined parameters."""
  33. err_dict = {}
  34. for name in args[::-1]:
  35. if not err_dict:
  36. err_dict.update({name: err_msg})
  37. else:
  38. err_dict = {name: err_dict}
  39. return err_dict
  40. def include_integer(low, high):
  41. """Check if the range [low, high) includes integer."""
  42. def _in_range(num, low, high):
  43. """check if num in [low, high)"""
  44. return low <= num < high
  45. if _in_range(math.ceil(low), low, high) or _in_range(math.floor(high), low, high):
  46. return True
  47. return False
  48. class TunerSchema(Schema):
  49. """Schema for tuner."""
  50. dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
  51. name = fields.Str(required=True,
  52. validate=validate.OneOf(TuneMethod.list_members()),
  53. error_messages=_generate_schema_err_msg("Value should be in %s." % TuneMethod.list_members()))
  54. args = fields.Dict(error_messages=dict_err_msg)
  55. @validates("args")
  56. def check_args(self, data):
  57. """Check args for tuner."""
  58. data_keys = list(data.keys())
  59. support_args = GPSupportArgs.list_members()
  60. if not set(data_keys).issubset(set(support_args)):
  61. raise ValidationError("Only support setting %s for tuner. "
  62. "Current key(s): %s." % (support_args, data_keys))
  63. method = data.get(GPSupportArgs.METHOD.value)
  64. if not isinstance(method, str):
  65. raise ValidationError("The 'method' type should be str.")
  66. if method not in AcquisitionFunctionEnum.list_members():
  67. raise ValidationError("Supported acquisition function must be one of %s. Current value is %r." %
  68. (AcquisitionFunctionEnum.list_members(), method))
  69. class ParameterSchema(Schema):
  70. """Schema for parameter."""
  71. number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG)
  72. list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list")
  73. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string")
  74. bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
  75. choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
  76. type = fields.Str(error_messages=str_err_msg)
  77. source = fields.Str(error_messages=str_err_msg)
  78. @validates("bounds")
  79. def check_bounds(self, bounds):
  80. """Check if bounds are valid."""
  81. if len(bounds) != _BOUND_LEN:
  82. raise ValidationError("Length of bounds should be %s." % _BOUND_LEN)
  83. if bounds[1] <= bounds[0]:
  84. raise ValidationError("The upper bound must be greater than lower bound. "
  85. "The range is [lower_bound, upper_bound).")
  86. @validates("choice")
  87. def check_choice(self, choice):
  88. """Check if choice is valid."""
  89. if not choice:
  90. raise ValidationError("It is empty, please fill in at least one value.")
  91. @validates("type")
  92. def check_type(self, type_in):
  93. """Check if type is valid."""
  94. if type_in not in HyperParamType.list_members():
  95. raise ValidationError("It should be in %s." % HyperParamType.list_members())
  96. @validates("source")
  97. def check_source(self, source):
  98. """Check if source is valid."""
  99. if source not in HyperParamSource.list_members():
  100. raise ValidationError(_VALUE_ERR_MSG % (HyperParamSource.list_members(), source))
  101. @validates_schema
  102. def check_combination(self, data, **kwargs):
  103. """check the combination of parameters."""
  104. bound_key = HyperParamKey.BOUND.value
  105. choice_key = HyperParamKey.CHOICE.value
  106. type_key = HyperParamKey.TYPE.value
  107. # check bound and type
  108. bounds = data.get(bound_key)
  109. param_type = data.get(type_key)
  110. if bounds is not None:
  111. if param_type is None:
  112. raise ValidationError("If %r is specified, the %r should be specified also." %
  113. (HyperParamKey.BOUND.value, HyperParamKey.TYPE.value))
  114. if param_type == HyperParamType.INT.value and not include_integer(bounds[0], bounds[1]):
  115. raise ValidationError("No integer in 'bounds', please modify it.")
  116. # check bound and choice
  117. if (bound_key in data and choice_key in data) or (bound_key not in data and choice_key not in data):
  118. raise ValidationError("Only one of [%r, %r] should be specified." %
  119. (bound_key, choice_key))
  120. class TargetSchema(Schema):
  121. """Schema for target."""
  122. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string")
  123. group = fields.Str(error_messages=str_err_msg)
  124. name = fields.Str(required=True, error_messages=str_err_msg)
  125. goal = fields.Str(error_messages=str_err_msg)
  126. @validates("group")
  127. def check_group(self, group):
  128. """Check if bounds are valid."""
  129. if group not in TargetGroup.list_members():
  130. raise ValidationError(_VALUE_ERR_MSG % (TargetGroup.list_members(), group))
  131. @validates("goal")
  132. def check_goal(self, goal):
  133. """Check if source is valid."""
  134. if goal not in TargetGoal.list_members():
  135. raise ValidationError(_VALUE_ERR_MSG % (TargetGoal.list_members(), goal))
  136. @validates_schema
  137. def check_combination(self, data, **kwargs):
  138. """check the combination of parameters."""
  139. if TargetKey.GROUP.value not in data:
  140. # if name is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
  141. return
  142. name = data.get(TargetKey.NAME.value)
  143. group = data.get(TargetKey.GROUP.value)
  144. if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members():
  145. raise ValidationError({
  146. TargetKey.GROUP.value: "This target is not system defined. Current group is %r." % group})
  147. class OptimizerConfig(Schema):
  148. """Define the search model condition parameter schema."""
  149. dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
  150. str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "string")
  151. summary_base_dir = fields.Str(required=True, error_messages=str_err_msg)
  152. command = fields.Str(required=True, error_messages=str_err_msg)
  153. tuner = fields.Dict(required=True, error_messages=dict_err_msg)
  154. target = fields.Dict(required=True, error_messages=dict_err_msg)
  155. parameters = fields.Dict(required=True, error_messages=dict_err_msg)
  156. def _pre_check_tunable_system_parameters(self, name, value):
  157. self._check_param_type_tunable_system_parameters(name, value)
  158. # need to check param type in choice before checking the value
  159. self._check_param_type_choice_tunable_system_parameters(name, value)
  160. self._check_param_value_tunable_system_parameters(name, value)
  161. def _check_param_type_tunable_system_parameters(self, name, value):
  162. """Check param type for tunable system parameters."""
  163. param_type = value.get(HyperParamKey.TYPE.value)
  164. if param_type is None:
  165. return
  166. if name == TunableSystemDefinedParams.LEARNING_RATE.value:
  167. if param_type != HyperParamType.FLOAT.value:
  168. err_msg = "The value(s) should be float number, " \
  169. "please config its type as %r." % HyperParamType.FLOAT.value
  170. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
  171. elif param_type != HyperParamType.INT.value:
  172. err_msg = "The value(s) should be integer, please config its type as %r." % HyperParamType.INT.value
  173. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.TYPE.value))
  174. def _check_param_type_choice_tunable_system_parameters(self, name, value):
  175. """Check param type in choice for tunable system parameters."""
  176. choice = value.get(HyperParamKey.CHOICE.value)
  177. if choice is None:
  178. return
  179. if name == TunableSystemDefinedParams.LEARNING_RATE.value:
  180. if list(filter(lambda x: not isinstance(x, float), choice)):
  181. err_msg = "The value(s) should be float number."
  182. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
  183. elif list(filter(lambda x: isinstance(x, bool) or not isinstance(x, int), choice)):
  184. # isinstance(x, int) will return True if x is bool. use 'type(x)' will not pass lint.
  185. err_msg = "The value(s) should be integer."
  186. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
  187. def _check_param_value_tunable_system_parameters(self, name, value):
  188. """Check param value for tunable system parameters."""
  189. bound = value.get(HyperParamKey.BOUND.value)
  190. choice = value.get(HyperParamKey.CHOICE.value)
  191. err_msg = "The value(s) should be positive number."
  192. if bound is not None and bound[0] <= 0:
  193. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.BOUND.value))
  194. if choice is not None and min(choice) <= 0:
  195. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
  196. if name == TunableSystemDefinedParams.LEARNING_RATE.value:
  197. if bound is not None and bound[1] > 1:
  198. err_msg = "The upper bound should be less than and equal to 1."
  199. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.BOUND.value))
  200. if choice is not None and max(choice) >= 1:
  201. err_msg = "The value(s) should be float number less than 1."
  202. raise ValidationError(_generate_err_msg_for_nested_keys(err_msg, name, HyperParamKey.CHOICE.value))
  203. @validates("tuner")
  204. def check_tuner(self, data):
  205. """Check tuner."""
  206. err = TunerSchema().validate(data)
  207. if err:
  208. raise ValidationError(err)
  209. @validates("parameters")
  210. def check_parameters(self, parameters):
  211. """Check parameters."""
  212. for name, value in parameters.items():
  213. if not is_param_name_valid(name):
  214. raise ValidationError("Parameter name %r is not a valid name, only number(0-9), alphabet(a-z, A-Z) "
  215. "and underscore(_) characters are allowed in name." % name)
  216. is_system_param = False
  217. source = value.get(HyperParamKey.SOURCE.value)
  218. if source in [None, HyperParamSource.SYSTEM_DEFINED.value] and \
  219. name in TunableSystemDefinedParams.list_members():
  220. is_system_param = True
  221. if is_system_param:
  222. self._pre_check_tunable_system_parameters(name, value)
  223. err = ParameterSchema().validate(value)
  224. if err:
  225. raise ValidationError({name: err})
  226. if source is None:
  227. # if params is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
  228. continue
  229. if source == HyperParamSource.SYSTEM_DEFINED.value and \
  230. name not in TunableSystemDefinedParams.list_members():
  231. raise ValidationError({
  232. name: {"source": "This param is not system defined. Current source is %r." % source}})
  233. @validates("target")
  234. def check_target(self, target):
  235. """Check target."""
  236. err = TargetSchema().validate(target)
  237. if err:
  238. raise ValidationError(err)