
tuner.py 6.8 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""General tuner."""
import json
import os
import shlex
import subprocess
import uuid

import yaml
from marshmallow import ValidationError

from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
from mindinsight.lineagemgr.model import get_lineage_table, LineageTable
from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
from mindinsight.optimizer.common.enums import TuneMethod
from mindinsight.optimizer.common.exceptions import OptimizerTerminateError, ConfigParamError
from mindinsight.optimizer.common.log import logger
from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig
from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner
from mindinsight.optimizer.utils.param_handler import organize_params_target
from mindinsight.optimizer.utils.utils import get_nested_message
from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError

_OK = 0

class Tuner:
    """
    Tuner for auto tuning.

    Args:
        config_path (str): path to a YAML config file containing settings for the tuner,
            the optimization target and the search parameters.

    Raises:
        FileSystemPermissionError: the config file cannot be opened due to a permission error.
        UnknownError: any other error raised while loading the config file.
    """
    def __init__(self, config_path: str):
        self._config_info = self._validate_config(config_path)
        self._summary_base_dir = self._config_info.get('summary_base_dir')
        self._dir_prefix = 'train'

    def _validate_config(self, config_path):
        """Check config_path."""
        config_path = self._normalize_path("config_path", config_path)
        try:
            with open(config_path, "r") as file:
                config_info = yaml.safe_load(file)
        except PermissionError as exc:
            raise FileSystemPermissionError("Can not open config file. Detail: %s." % str(exc))
        except Exception as exc:
            raise UnknownError("Detail: %s." % str(exc))

        self._validate_config_schema(config_info)
        config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir'))
        self._make_summary_base_dir(config_info['summary_base_dir'])
        return config_info
    def _validate_config_schema(self, config_info):
        """Validate the config content against the OptimizerConfig schema."""
        error = OptimizerConfig().validate(config_info)
        if error:
            err_msg = get_nested_message(error)
            raise ConfigParamError(err_msg)

    def _make_summary_base_dir(self, summary_base_dir):
        """Check and make summary_base_dir."""
        if not os.path.exists(summary_base_dir):
            # Owner-only access: umask 0o077 masks group/other bits, mode 0o700 grants rwx to the owner.
            permissions = os.R_OK | os.W_OK | os.X_OK
            os.umask(permissions << 3 | permissions)
            mode = permissions << 6
            try:
                logger.info("The summary_base_dir is generated automatically, path is %s.", summary_base_dir)
                os.makedirs(summary_base_dir, mode=mode, exist_ok=True)
            except OSError as exc:
                raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc))
    def _normalize_path(self, param_name, path):
        """Normalize config path."""
        path = os.path.realpath(path)
        try:
            path = safe_normalize_path(
                path, param_name, None, check_absolute_path=True
            )
        except ValidationError:
            logger.error("The %r is invalid.", param_name)
            raise ParamValueError("The %r is invalid." % param_name)
        return path

    def _update_from_lineage(self):
        """Update lineage from lineagemgr."""
        try:
            lineage_table = get_lineage_table(summary_base_dir=self._summary_base_dir)
        except MindInsightException as err:
            logger.info("Can not query lineage. Detail: %s", str(err))
            lineage_table = None
        self._lineage_table = lineage_table
    def optimize(self, max_expr_times=1):
        """Method for auto tuning."""
        target_info = self._config_info.get('target')
        params_info = self._config_info.get('parameters')
        command = self._config_info.get('command')
        tuner = self._config_info.get('tuner')
        for _ in range(max_expr_times):
            # Refresh lineage data so that earlier trials inform the next suggestion.
            self._update_from_lineage()
            suggestion, user_defined_info = self._suggest(self._lineage_table, params_info, target_info, tuner)

            hyper_config = {
                'params': suggestion,
                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}'),
                'custom_lineage_data': user_defined_info
            }
            logger.info("Suggest values are: %s.", suggestion)

            # Pass the suggested hyperparameters to the training command via an environment
            # variable, then run one trial and wait for it to finish.
            os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config)
            s = subprocess.Popen(shlex.split(command))
            s.wait()
            if s.returncode != _OK:
                logger.error("An error occurred during execution, the auto tuning will be terminated.")
                raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.")
    def _get_tuner(self, tuner):
        """Get tuner."""
        if tuner is None:
            return GPBaseTuner()

        tuner_name = tuner.get("name").lower()
        if tuner_name not in TuneMethod.list_members():
            raise ParamValueError("'tune_method' should be in %s." % TuneMethod.list_members())

        args = tuner.get("args")
        if args is not None and args.get("method") is not None:
            return GPBaseTuner(args.get("method"))

        # Only the Gaussian process regressor is supported currently.
        return GPBaseTuner()
    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, tuner):
        """Get suggestions for targets."""
        tuner = self._get_tuner(tuner)
        param_matrix, target_column = organize_params_target(lineage_table, params_info, target_info)
        suggestion, user_defined_info = tuner.suggest(param_matrix, target_column, params_info)
        return suggestion, user_defined_info
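
Usage sketch (not part of tuner.py): the snippet below shows one way to drive Tuner with an illustrative YAML config whose top-level keys mirror what the class reads (summary_base_dir, command, tuner, target, parameters). The field names inside target and parameters, the tuner name "gp", the import path and train.py are assumptions for illustration only; the authoritative schema is OptimizerConfig.

# Hypothetical driver script -- a sketch under the assumptions noted above,
# not MindInsight's documented API.
from mindinsight.optimizer.tuner import Tuner  # import path assumed

EXAMPLE_CONFIG = """\
summary_base_dir: /tmp/optimizer_summaries   # each trial writes its summary under this directory
command: python train.py                     # training command launched once per trial
tuner:
  name: gp                                   # tuner name; must be a member of TuneMethod
target:
  name: accuracy                             # field names here are assumptions
  goal: maximize
parameters:
  learning_rate:
    bounds: [0.0001, 0.1]
    type: float
"""

with open("config.yaml", "w") as config_file:
    config_file.write(EXAMPLE_CONFIG)

# Run up to 5 trials; Tuner.optimize() passes each suggestion to the training
# process through the HYPER_CONFIG_ENV_NAME environment variable.
Tuner("config.yaml").optimize(max_expr_times=5)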