You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

param_handler.py 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utils for params."""
  16. import math
  17. import numpy as np
  18. from mindinsight.lineagemgr.model import LineageTable, USER_DEFINED_PREFIX, METRIC_PREFIX
  19. from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType, HyperParamSource, TargetKey, \
  20. TargetGoal, TunableSystemDefinedParams, TargetGroup, SystemDefinedTargets
  21. from mindinsight.optimizer.common.log import logger
  22. def generate_param(param_info, n=1):
  23. """Generate param."""
  24. value = None
  25. if HyperParamKey.BOUND.value in param_info:
  26. bound = param_info[HyperParamKey.BOUND.value]
  27. value = np.random.uniform(bound[0], bound[1], n)
  28. if param_info[HyperParamKey.TYPE.value] == HyperParamType.INT.value:
  29. value = value.astype(HyperParamType.INT.value)
  30. if HyperParamKey.CHOICE.value in param_info:
  31. indexes = np.random.randint(0, len(param_info[HyperParamKey.CHOICE.value]), n)
  32. value = [param_info[HyperParamKey.CHOICE.value][index] for index in indexes]
  33. if HyperParamKey.DECIMAL.value in param_info:
  34. value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
  35. return np.array(value)
  36. def generate_arrays(params_info: dict, n=1):
  37. """Generate arrays."""
  38. suggest_params = None
  39. for _, param_info in params_info.items():
  40. suggest_param = generate_param(param_info, n).reshape((-1, 1))
  41. if suggest_params is None:
  42. suggest_params = suggest_param
  43. else:
  44. suggest_params = np.hstack((suggest_params, suggest_param))
  45. if n == 1:
  46. return suggest_params[0]
  47. return suggest_params
  48. def match_value_type(array, params_info: dict):
  49. """Make array match params type."""
  50. array_new = []
  51. index = 0
  52. for _, param_info in params_info.items():
  53. value = array[index]
  54. bound = param_info.get(HyperParamKey.BOUND.value)
  55. choice = param_info.get(HyperParamKey.CHOICE.value)
  56. if bound is not None:
  57. value = max(bound[0], array[index])
  58. value = min(bound[1], value)
  59. if choice is not None:
  60. nearest_index = int(np.argmin(np.fabs(np.array(choice) - value)))
  61. value = choice[nearest_index]
  62. if param_info.get(HyperParamKey.TYPE.value) == HyperParamType.INT.value:
  63. value = int(value)
  64. if bound is not None and value < bound[0]:
  65. value = math.ceil(bound[0])
  66. elif bound is not None and value >= bound[1]:
  67. # bound[1] is 2.0, value is 1; bound[1] is 2.1, value is 2
  68. value = math.floor(bound[1]) - 1
  69. if HyperParamKey.DECIMAL.value in param_info:
  70. value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
  71. array_new.append(value)
  72. index += 1
  73. return array_new
  74. def organize_params_target(lineage_table: LineageTable, params_info: dict, target_info):
  75. """Organize params and target."""
  76. empty_result = np.array([])
  77. if lineage_table is None:
  78. return empty_result, empty_result
  79. param_keys = []
  80. for param_key, param_info in params_info.items():
  81. # It will be a user_defined param:
  82. # 1. if 'source' is specified as 'user_defined'
  83. # 2. if 'source' is not specified and the param is not a system_defined key
  84. source = param_info.get(HyperParamKey.SOURCE.value)
  85. prefix = _get_prefix(param_key, source, HyperParamSource.USER_DEFINED.value,
  86. USER_DEFINED_PREFIX, TunableSystemDefinedParams.list_members())
  87. param_key = f'{prefix}{param_key}'
  88. if prefix == USER_DEFINED_PREFIX:
  89. param_info[HyperParamKey.SOURCE.value] = HyperParamSource.USER_DEFINED.value
  90. else:
  91. param_info[HyperParamKey.SOURCE.value] = HyperParamSource.SYSTEM_DEFINED.value
  92. param_keys.append(param_key)
  93. target_name = target_info[TargetKey.NAME.value]
  94. group = target_info.get(TargetKey.GROUP.value)
  95. prefix = _get_prefix(target_name, group, TargetGroup.METRIC.value,
  96. METRIC_PREFIX, SystemDefinedTargets.list_members())
  97. target_name = prefix + target_name
  98. lineage_df = lineage_table.dataframe_data
  99. try:
  100. lineage_df = lineage_df[param_keys + [target_name]]
  101. lineage_df = lineage_df.dropna(axis=0, how='any')
  102. target_column = np.array(lineage_df[target_name])
  103. if TargetKey.GOAL.value in target_info and \
  104. target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value:
  105. target_column = -target_column
  106. return np.array(lineage_df[param_keys]), target_column
  107. except KeyError as exc:
  108. logger.warning("Some keys not exist in specified params or target. It will suggest params randomly."
  109. "Detail: %s.", str(exc))
  110. return empty_result, empty_result
  111. def _get_prefix(name, field, other_defined_field, other_defined_prefix, system_defined_fields):
  112. if (field == other_defined_field) or (field is None and name not in system_defined_fields):
  113. return other_defined_prefix
  114. return ''