diff --git a/mindinsight/backend/optimizer/optimizer_api.py b/mindinsight/backend/optimizer/optimizer_api.py
index cf82f3ca..fecd13ea 100644
--- a/mindinsight/backend/optimizer/optimizer_api.py
+++ b/mindinsight/backend/optimizer/optimizer_api.py
@@ -14,12 +14,11 @@
 # ============================================================================
 """Optimizer API module."""
 import json
-import pandas as pd
 from flask import Blueprint, jsonify, request
 
 from mindinsight.conf import settings
 from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
-from mindinsight.lineagemgr.model import get_flattened_lineage, LineageTable
+from mindinsight.lineagemgr.model import get_lineage_table
 from mindinsight.optimizer.common.enums import ReasonCode
 from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
 from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
@@ -44,8 +43,7 @@ def get_optimize_targets():
 
 def _get_optimize_targets(data_manager, search_condition=None):
     """Get optimize targets."""
-    flatten_lineage = get_flattened_lineage(data_manager, search_condition)
-    table = LineageTable(pd.DataFrame(flatten_lineage))
+    table = get_lineage_table(data_manager, search_condition)
 
     target_summaries = []
     for target in table.target_names:
diff --git a/mindinsight/lineagemgr/common/validator/model_parameter.py b/mindinsight/lineagemgr/common/validator/model_parameter.py
index aff90909..19af0bdc 100644
--- a/mindinsight/lineagemgr/common/validator/model_parameter.py
+++ b/mindinsight/lineagemgr/common/validator/model_parameter.py
@@ -179,8 +179,8 @@ class SearchModelConditionParameter(Schema):
             raise ValidationError("Given lineage type should be one of %s." % lineage_types)
 
     @pre_load
-    def check_comparision(self, data, **kwargs):
-        """Check comparision for all parameters in schema."""
+    def check_comparison(self, data, **kwargs):
+        """Check comparison for all parameters in schema."""
         for attr, condition in data.items():
             if attr in ["limit", "offset", "sorted_name", "sorted_type", 'lineage_type']:
                 continue
diff --git a/mindinsight/lineagemgr/model.py b/mindinsight/lineagemgr/model.py
index e8eec69b..7f34b070 100644
--- a/mindinsight/lineagemgr/model.py
+++ b/mindinsight/lineagemgr/model.py
@@ -27,8 +27,8 @@ from mindinsight.optimizer.common.enums import ReasonCode
 from mindinsight.optimizer.utils.utils import is_simple_numpy_number
 from mindinsight.utils.exceptions import MindInsightException
 
-_METRIC_PREFIX = "[M]"
-_USER_DEFINED_PREFIX = "[U]"
+METRIC_PREFIX = "[M]"
+USER_DEFINED_PREFIX = "[U]"
 
 USER_DEFINED_INFO_LIMIT = 100
 
@@ -85,7 +85,7 @@ def get_flattened_lineage(data_manager, search_condition=None):
     for index, lineage in enumerate(lineages):
         flatten_dict['train_id'].append(lineage.get("summary_dir"))
         for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
-            if key.startswith(_USER_DEFINED_PREFIX) and key not in flatten_dict:
+            if key.startswith(USER_DEFINED_PREFIX) and key not in flatten_dict:
                 if user_count > USER_DEFINED_INFO_LIMIT:
                     log.warning("The user_defined_info has reached the limit %s. "
                                 "%r is ignored", USER_DEFINED_INFO_LIMIT, key)
%r is ignored", USER_DEFINED_INFO_LIMIT, key) @@ -105,10 +105,10 @@ def _flatten_lineage(lineage): for key, val in lineage.items(): if key == 'metric': for k, v in val.items(): - yield f'{_METRIC_PREFIX}{k}', v + yield f'{METRIC_PREFIX}{k}', v elif key == 'user_defined': for k, v in val.items(): - yield f'{_USER_DEFINED_PREFIX}{k}', v + yield f'{USER_DEFINED_PREFIX}{k}', v else: yield key, val @@ -144,7 +144,7 @@ class LineageTable: self._df = self._df.drop(columns=columns_to_drop) for name in columns_to_drop: - if not name.startswith(_USER_DEFINED_PREFIX): + if not name.startswith(USER_DEFINED_PREFIX): continue self._drop_columns_info.append({ "name": name, @@ -155,7 +155,7 @@ class LineageTable: @property def target_names(self): """Get names for optimize targets (eg loss, accuracy).""" - target_names = [name for name in self._df.columns if name.startswith(_METRIC_PREFIX)] + target_names = [name for name in self._df.columns if name.startswith(METRIC_PREFIX)] if self._LOSS_NAME in self._df.columns: target_names.append(self._LOSS_NAME) return target_names @@ -167,7 +167,7 @@ class LineageTable: hyper_param_names = [ name for name in self._df.columns - if not name.startswith(_METRIC_PREFIX) and name not in blocked_names] + if not name.startswith(METRIC_PREFIX) and name not in blocked_names] if self._LOSS_NAME in hyper_param_names: hyper_param_names.remove(self._LOSS_NAME) @@ -184,7 +184,7 @@ class LineageTable: @property def user_defined_hyper_param_names(self): """Get user defined hyper param names.""" - names = [name for name in self._df.columns if name.startswith(_USER_DEFINED_PREFIX)] + names = [name for name in self._df.columns if name.startswith(USER_DEFINED_PREFIX)] return names def get_column(self, name): @@ -220,3 +220,11 @@ class LineageTable: def drop_column_info(self): """Get dropped columns info.""" return self._drop_columns_info + + +def get_lineage_table(data_manager, search_condition=None): + """Get lineage table from data_manager.""" + lineage_table = get_flattened_lineage(data_manager, search_condition) + lineage_table = LineageTable(pd.DataFrame(lineage_table)) + + return lineage_table diff --git a/mindinsight/optimizer/__init__.py b/mindinsight/optimizer/__init__.py index 6ccbf1a4..6db9025d 100644 --- a/mindinsight/optimizer/__init__.py +++ b/mindinsight/optimizer/__init__.py @@ -17,3 +17,5 @@ Optimizer. Optimizer provides optimization target distribution, parameter importance, etc. """ + +from mindinsight.optimizer.hyper_config import HyperConfig diff --git a/mindinsight/optimizer/common/constants.py b/mindinsight/optimizer/common/constants.py new file mode 100644 index 00000000..2241626d --- /dev/null +++ b/mindinsight/optimizer/common/constants.py @@ -0,0 +1,16 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/mindinsight/optimizer/common/constants.py b/mindinsight/optimizer/common/constants.py
new file mode 100644
index 00000000..2241626d
--- /dev/null
+++ b/mindinsight/optimizer/common/constants.py
@@ -0,0 +1,16 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Common constants for optimizer."""
+HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
diff --git a/mindinsight/optimizer/common/enums.py b/mindinsight/optimizer/common/enums.py
index 8e1c814a..adb5216a 100644
--- a/mindinsight/optimizer/common/enums.py
+++ b/mindinsight/optimizer/common/enums.py
@@ -30,3 +30,46 @@ class ReasonCode(BaseEnum):
     NOT_ALL_NUMBERS = 1
     SAMPLES_NOT_ENOUGH = 2
     CORRELATION_NAN = 3
+
+
+class AcquisitionFunctionEnum(BaseEnum):
+    """Enum for acquisition function method."""
+    # Upper confidence bound
+    UCB = 'ucb'
+    # Probability of improvement
+    PI = 'pi'
+    # Expected improvement
+    EI = 'ei'
+
+
+class TuneMethod(BaseEnum):
+    """Enum for tuning method."""
+    # Gaussian process regressor
+    GP = 'gp'
+
+
+class HyperParamKey(BaseEnum):
+    """Config keys for hyper parameters."""
+    BOUND = 'bounds'
+    CHOICE = 'choice'
+    DECIMAL = 'decimal'
+    TYPE = 'type'
+
+
+class HyperParamType(BaseEnum):
+    """Value types for hyper parameters."""
+    INT = 'int'
+    FLOAT = 'float'
+
+
+class TargetKey(BaseEnum):
+    """Config keys for target."""
+    GROUP = 'group'
+    NAME = 'name'
+    GOAL = 'goal'
+
+
+class TargetGoal(BaseEnum):
+    """Goal for target."""
+    MAXIMUM = 'maximize'
+    MINIMUM = 'minimize'
diff --git a/mindinsight/optimizer/common/exceptions.py b/mindinsight/optimizer/common/exceptions.py
index 1668f0cf..bf3606a7 100644
--- a/mindinsight/optimizer/common/exceptions.py
+++ b/mindinsight/optimizer/common/exceptions.py
@@ -32,3 +32,19 @@ class CorrelationNanError(MindInsightException):
         super(CorrelationNanError, self).__init__(OptimizerErrors.CORRELATION_NAN,
                                                   error_msg,
                                                   http_code=400)
+
+
+class HyperConfigError(MindInsightException):
+    """Hyper config error."""
+    def __init__(self, error_msg="Hyper config is not correct."):
+        super(HyperConfigError, self).__init__(OptimizerErrors.HYPER_CONFIG_ERROR,
+                                               error_msg,
+                                               http_code=400)
+
+
+class OptimizerTerminateError(MindInsightException):
+    """Optimizer terminate error."""
+    def __init__(self, error_msg="Auto tuning has been terminated."):
+        super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE,
+                                                      error_msg,
+                                                      http_code=400)
diff --git a/mindinsight/optimizer/hyper_config.py b/mindinsight/optimizer/hyper_config.py
new file mode 100644
index 00000000..5e0ba3b6
--- /dev/null
+++ b/mindinsight/optimizer/hyper_config.py
@@ -0,0 +1,94 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Hyper config."""
+import json
+import os
+from attrdict import AttrDict
+
+from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
+from mindinsight.optimizer.common.exceptions import HyperConfigError
+
+_HYPER_CONFIG_LEN_LIMIT = 100000
+
+
+class HyperConfig:
+    """
+    Hyper config.
+
+    Init hyper config:
+    >>> hyper_config = HyperConfig()
+
+    Get suggested params:
+    >>> param_obj = hyper_config.params
+    >>> learning_rate = param_obj.learning_rate
+
+    Get summary dir:
+    >>> summary_dir = hyper_config.summary_dir
+
+    Record by SummaryCollector:
+    >>> summary_cb = SummaryCollector(summary_dir)
+    """
+    def __init__(self):
+        self._init_validate_hyper_config()
+
+    def _init_validate_hyper_config(self):
+        """Init and validate hyper config."""
+        hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
+        if hyper_config is None:
+            raise HyperConfigError("Hyper config is not in system environment.")
+        if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT:
+            raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of "
+                                   "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
+
+        try:
+            hyper_config = json.loads(hyper_config)
+        except TypeError as exc:
+            raise HyperConfigError("Hyper config type error. Detail: %s." % str(exc))
+        except Exception as exc:
+            raise HyperConfigError("Hyper config decode error. Detail: %s." % str(exc))
+
+        self._validate_hyper_config(hyper_config)
+        self._summary_dir = hyper_config.get('summary_dir')
+        self._param_obj = AttrDict(hyper_config.get('params'))
+
+    def _validate_hyper_config(self, hyper_config):
+        """Validate hyper config."""
+        for key in ['summary_dir', 'params']:
+            if key not in hyper_config:
+                raise HyperConfigError("%r must exist in hyper_config." % key)
+
+        # validate summary_dir
+        summary_dir = hyper_config.get('summary_dir')
+        if not isinstance(summary_dir, str):
+            raise HyperConfigError("The 'summary_dir' should be a string.")
+        hyper_config['summary_dir'] = os.path.realpath(summary_dir)
+
+        # validate params
+        params = hyper_config.get('params')
+        if not isinstance(params, dict):
+            raise HyperConfigError("'params' is not a dict.")
+        for key, value in params.items():
+            if not isinstance(value, (int, float)):
+                raise HyperConfigError("The value of %r is not an integer or float." % key)
+
+    @property
+    def params(self):
+        """Get params."""
+        return self._param_obj
+
+    @property
+    def summary_dir(self):
+        """Get train summary dir path."""
+        return self._summary_dir
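Reviewer note: a sketch of the contract `HyperConfig` validates — a JSON object in the `MINDINSIGHT_HYPER_CONFIG` environment variable carrying a `summary_dir` string and a `params` dict of numbers. In practice `Tuner.optimize` sets this variable before launching the training command; all values here are made up.

```python
# Sketch of the environment payload HyperConfig expects (values are made up).
import json
import os

from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
from mindinsight.optimizer.hyper_config import HyperConfig

os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps({
    "summary_dir": "/tmp/summaries/train_1",  # hypothetical directory
    "params": {"learning_rate": 0.002, "batch_size": 32},
})

hyper_config = HyperConfig()
print(hyper_config.params.learning_rate)  # 0.002, attribute access via AttrDict
print(hyper_config.summary_dir)           # realpath of the directory above
```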
diff --git a/mindinsight/optimizer/tuner.py b/mindinsight/optimizer/tuner.py
new file mode 100644
index 00000000..116a501a
--- /dev/null
+++ b/mindinsight/optimizer/tuner.py
@@ -0,0 +1,163 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""General tuner."""
+import json
+import os
+import shlex
+import subprocess
+import uuid
+import yaml
+
+from marshmallow import ValidationError
+
+from mindinsight.datavisual.data_transform.data_manager import DataManager
+from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
+from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
+from mindinsight.lineagemgr.model import get_lineage_table, LineageTable, METRIC_PREFIX
+from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
+from mindinsight.optimizer.common.enums import TuneMethod, TargetKey, TargetGoal
+from mindinsight.optimizer.common.exceptions import OptimizerTerminateError
+from mindinsight.optimizer.common.log import logger
+from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner
+from mindinsight.optimizer.utils.param_handler import organize_params_target
+from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError
+
+_OK = 0
+
+
+class Tuner:
+    """
+    Tuner for auto tuning.
+
+    Args:
+        config_path (str): config path, a yaml format file containing settings about tuner, target and parameters, etc.
+
+    Raises:
+        FileSystemPermissionError: the config file can not be opened due to insufficient permission.
+        UnknownError: any other exception.
+    """
+    def __init__(self, config_path: str):
+        self._config_info = self._validate_config(config_path)
+        self._summary_base_dir = self._config_info.get('summary_base_dir')
+        self._data_manager = self._init_data_manager()
+        self._dir_prefix = 'train'
+
+    def _validate_config(self, config_path):
+        """Check config_path."""
+        config_path = self._normalize_path("config_path", config_path)
+        try:
+            with open(config_path, "r") as file:
+                config_info = yaml.safe_load(file)
+        except PermissionError as exc:
+            raise FileSystemPermissionError("Can not open config file. Detail: %s." % str(exc))
+        except Exception as exc:
+            raise UnknownError("Detail: %s." % str(exc))
+
+        # need to add validation for config_info: command, summary_base_dir, target and params.
+        config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir'))
+        self._make_summary_base_dir(config_info['summary_base_dir'])
+        return config_info
+
+    def _make_summary_base_dir(self, summary_base_dir):
+        """Check and make summary_base_dir."""
+        if not os.path.exists(summary_base_dir):
+            permissions = os.R_OK | os.W_OK | os.X_OK
+            os.umask(permissions << 3 | permissions)
+            mode = permissions << 6
+            try:
+                logger.info("The summary_base_dir is generated automatically, path is %s.", summary_base_dir)
+                os.makedirs(summary_base_dir, mode=mode, exist_ok=True)
+            except OSError as exc:
+                raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc))
+
+    def _init_data_manager(self):
+        """Initialize data_manager."""
+        data_manager = DataManager(summary_base_dir=self._summary_base_dir)
+        data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater())
+
+        return data_manager
+
+    def _normalize_path(self, param_name, path):
+        """Normalize config path."""
+        path = os.path.realpath(path)
+        try:
+            path = safe_normalize_path(
+                path, param_name, None, check_absolute_path=True
+            )
+        except ValidationError:
+            logger.error("The %r is invalid.", param_name)
+            raise ParamValueError("The %r is invalid." % param_name)
+
+        return path
+
+    def _update_from_lineage(self):
+        """Update lineage from lineagemgr."""
+        self._data_manager.start_load_data(reload_interval=0).join()
+
+        try:
+            lineage_table = get_lineage_table(self._data_manager)
+        except MindInsightException as err:
+            logger.info("Can not query lineage. Detail: %s", str(err))
+            lineage_table = None
+
+        self._lineage_table = lineage_table
+
+    def optimize(self, max_expr_times=1):
+        """Method for auto tuning."""
+        target_info = self._config_info.get('target')
+        params_info = self._config_info.get('parameters')
+        command = self._config_info.get('command')
+        tuner = self._config_info.get('tuner')
+        for _ in range(max_expr_times):
+            self._update_from_lineage()
+            suggestion = self._suggest(self._lineage_table, params_info, target_info, method=tuner.get("name"))
+
+            hyper_config = {
+                'params': suggestion,
+                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}')
+            }
+            os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config)
+            process = subprocess.Popen(shlex.split(command))
+            process.wait()
+            if process.returncode != _OK:
+                logger.error("An error occurred during execution, the auto tuning will be terminated.")
+                raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.")
+
+    def _get_tuner(self, tune_method=TuneMethod.GP.value):
+        """Get tuner."""
+        if tune_method.lower() not in TuneMethod.list_members():
+            raise ParamValueError("'tune_method' should be in %s." % TuneMethod.list_members())
+
+        # Only support gaussian process regressor currently.
+        return GPBaseTuner()
+
+    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, method):
+        """Get suggestions for targets."""
+        tuner = self._get_tuner(method)
+        target_name = target_info[TargetKey.NAME.value]
+        if TargetKey.GROUP.value in target_info and target_info[TargetKey.GROUP.value] == 'metric':
+            target_name = METRIC_PREFIX + target_name
+        param_matrix, target_matrix = organize_params_target(lineage_table, params_info, target_name)
+
+        if len(param_matrix) == 0:
+            # No lineage history is available yet, so let the tuner suggest params randomly.
+            suggestion = tuner.suggest([], [], params_info)
+        else:
+            # Keep the target one-dimensional: GaussianProcessRegressor then predicts 1-D means,
+            # which is the shape the acquisition function expects.
+            target_column = target_matrix.values
+            if target_info.get(TargetKey.GOAL.value) == TargetGoal.MINIMUM.value:
+                # The GP tuner maximizes its objective, so negate targets that should be minimized.
+                target_column = -target_column
+
+            suggestion = tuner.suggest(param_matrix, target_column, params_info)
+
+        return suggestion
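Reviewer note: the YAML layout below is inferred from the keys `Tuner` actually reads (`summary_base_dir`, `command`, `tuner.name`, `target`, `parameters`) — the TODO in `_validate_config` says the schema is not validated yet, so treat this as a sketch rather than an official format. Paths and the training command are placeholders.

```python
# Sketch: driving the Tuner end to end. The config keys mirror what
# optimize()/_validate_config() read; the schema itself is inferred, not official.
import yaml

from mindinsight.optimizer.tuner import Tuner

config = {
    "command": "python train.py",          # hypothetical training script
    "summary_base_dir": "/tmp/summaries",  # hypothetical path
    "tuner": {"name": "gp"},               # TuneMethod.GP
    "target": {"group": "metric", "name": "accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"bounds": [0.0001, 0.01], "type": "float"},
        "batch_size": {"choice": [16, 32, 64], "type": "int"},
    },
}

with open("/tmp/tuner_config.yaml", "w") as file:
    yaml.safe_dump(config, file)

# Each iteration queries lineage, suggests params, exports MINDINSIGHT_HYPER_CONFIG
# and re-runs the training command.
Tuner("/tmp/tuner_config.yaml").optimize(max_expr_times=5)
```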
diff --git a/mindinsight/optimizer/tuners/base_tuner.py b/mindinsight/optimizer/tuners/base_tuner.py
new file mode 100644
index 00000000..149b6217
--- /dev/null
+++ b/mindinsight/optimizer/tuners/base_tuner.py
@@ -0,0 +1,22 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Base tuner."""
+from abc import abstractmethod
+
+
+class BaseTuner:
+    """Base class for tuners."""
+    @abstractmethod
+    def suggest(self, *args):
+        """Suggest method should be implemented."""
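Reviewer note: for illustration only, a trivial subclass satisfying the `BaseTuner` contract — not part of this change. It reuses the `param_handler` and `Transformer` utilities introduced in this PR to return the same dict shape `GPBaseTuner.suggest` produces.

```python
# Illustrative only: a hypothetical tuner that ignores history and samples randomly.
from mindinsight.optimizer.tuners.base_tuner import BaseTuner
from mindinsight.optimizer.utils.param_handler import generate_arrays
from mindinsight.optimizer.utils.transformer import Transformer


class RandomTuner(BaseTuner):
    """Suggest one random point from params_info, regardless of past trials."""

    def suggest(self, params, target, params_info):
        suggestion = generate_arrays(params_info)  # one row, one value per param
        return Transformer.transform_list_to_dict(params_info, suggestion)
```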
diff --git a/mindinsight/optimizer/tuners/gp_tuner.py b/mindinsight/optimizer/tuners/gp_tuner.py
new file mode 100644
index 00000000..adbf8976
--- /dev/null
+++ b/mindinsight/optimizer/tuners/gp_tuner.py
@@ -0,0 +1,178 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""GP Tuner."""
+import warnings
+import numpy as np
+
+from scipy.stats import norm
+from scipy.optimize import minimize
+from sklearn.gaussian_process import GaussianProcessRegressor
+from sklearn.gaussian_process.kernels import Matern
+
+from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey
+from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type
+from mindinsight.optimizer.utils.transformer import Transformer
+from mindinsight.utils.exceptions import ParamValueError
+from mindinsight.optimizer.tuners.base_tuner import BaseTuner
+
+
+class AcquisitionFunction:
+    """
+    A Gaussian process provides a probabilistic description of the objective function,
+    refined by sampling. Sampling usually balances two aspects:
+    - Explore: explore new spaces; this sampling helps to estimate more accurate results;
+    - Exploit: sample near the existing results (usually near the existing maximum value),
+      hoping to find larger results.
+
+    The purpose of the acquisition function is to balance these two sampling processes.
+    Supported acquisition functions:
+    - Probability of improvement.
+    - Expected improvement.
+    - Upper confidence bound: the weighted sum of posterior mean and posterior standard deviation.
+      Formula: result = exploitation + βt * exploration, where βt is an appropriate constant.
+
+    Args:
+        method (str): The method for acquisition function, including 'ucb', 'pi' and 'ei'.
+        beta (float): Trade-off param for upper confidence bound function.
+        beta_decay (float): The decay for beta. Formula: beta = beta * beta_decay.
+        beta_decay_delay (int): If the counter is bigger than beta_decay_delay, the beta begins to decay.
+        xi (float): Trade-off for expected improvement and probability of improvement.
+    """
+    def __init__(self, method: str, beta, xi, beta_decay=1, beta_decay_delay=0):
+        self._beta = beta
+        self._beta_decay = beta_decay
+        self._beta_decay_delay = beta_decay_delay
+        self._xi = xi
+        self._method = method.lower()
+        if self._method not in AcquisitionFunctionEnum.list_members():
+            raise ParamValueError(error_detail="The 'method' should be in %s."
+                                               % AcquisitionFunctionEnum.list_members())
+
+        self._counter = 0
+
+    def update(self):
+        """Update beta with its decay."""
+        self._counter += 1
+
+        if self._counter > self._beta_decay_delay and self._beta_decay < 1:
+            self._beta *= self._beta_decay
+
+    def ac(self, x, gp, y_max):
+        """Acquisition function."""
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            mean, std = gp.predict(x, return_std=True)
+
+        if self._method == AcquisitionFunctionEnum.UCB.value:
+            # Upper confidence bound.
+            res = mean + self._beta * std
+        elif self._method == AcquisitionFunctionEnum.EI.value:
+            # Expected improvement.
+            u_f = (mean - y_max - self._xi)
+            z = u_f / std
+            res = u_f * norm.cdf(z) + std * norm.pdf(z)
+        else:
+            # Probability of improvement.
+            z = (mean - y_max - self._xi) / std
+            res = norm.cdf(z)
+
+        return res
+
+
+class GPBaseTuner(BaseTuner):
+    """
+    Tuner using gaussian process regressor.
+
+    Args:
+        method (str): The method for acquisition function, including 'ucb', 'pi' and 'ei'.
+            Details in AcquisitionFunction.
+        beta (float): β, trade-off param for upper confidence bound function.
+        beta_decay (float): The decay for beta. beta = beta * beta_decay.
+        beta_decay_delay (int): If counter is bigger than beta_decay_delay, the beta begins to decay.
+        xi (float): ξ, trade-off for expected improvement and probability of improvement.
+        random_state (np.random.RandomState): If it is None, it will be assigned as RandomState.
+    """
+    def __init__(self,
+                 method=AcquisitionFunctionEnum.UCB.value,
+                 beta=2.576,
+                 beta_decay=1,
+                 beta_decay_delay=0,
+                 xi=0.0,
+                 random_state=None):
+        self._random_state = self._get_random_state(random_state)
+        self._utility_function = AcquisitionFunction(method=method,
+                                                     beta=beta,
+                                                     xi=xi,
+                                                     beta_decay=beta_decay,
+                                                     beta_decay_delay=beta_decay_delay)
+        self._gp = GaussianProcessRegressor(
+            kernel=Matern(nu=2.5),
+            alpha=1e-6,
+            normalize_y=True,
+            n_restarts_optimizer=5,
+            random_state=self._random_state
+        )
+
+    def _get_random_state(self, random_state=None):
+        """Get random state."""
+        if random_state is not None and not isinstance(random_state, (int, np.random.RandomState)):
+            raise ParamValueError("The 'random_state' should be None, integer or np.random.RandomState.")
+        if not isinstance(random_state, np.random.RandomState):
+            random_state = np.random.RandomState(random_state)
+        return random_state
+
+    def _acq_max(self, gp, y_max, bounds, params_info, n_warmup=10000, n_iter=10):
+        """Find the point that maximizes the acquisition function."""
+        x_tries = generate_arrays(params_info, n_warmup)
+        ys = self._utility_function.ac(x_tries, gp=gp, y_max=y_max)
+        x_max = x_tries[ys.argmax()]
+        max_acq = ys.max()
+
+        x_seeds = generate_arrays(params_info, n_iter)
+        for x_try in x_seeds:
+            res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max),
+                           x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B")
+
+            if not res.success:
+                continue
+
+            # Keep the best point found by the local optimizer.
+            if max_acq is None or -res.fun[0] >= max_acq:
+                x_max = res.x
+                max_acq = -res.fun[0]
+
+        return np.clip(x_max, bounds[:, 0], bounds[:, 1])
+
+    def suggest(self, params, target, params_info: dict):
+        """Get suggested values."""
+        bounds = []
+        for param_info in params_info.values():
+            bound = param_info[HyperParamKey.BOUND.value] if HyperParamKey.BOUND.value in param_info \
+                else param_info[HyperParamKey.CHOICE.value]
+            bounds.append([min(bound), max(bound)])
+        bounds = np.array(bounds)
+
+        min_lineage_rows = 2
+        if not np.array(params).any() or params.shape[0] < min_lineage_rows:
+            suggestion = generate_arrays(params_info)
+        else:
+            self._gp.fit(params, target)
+            suggestion = self._acq_max(
+                gp=self._gp,
+                y_max=target.max(),
+                bounds=bounds,
+                params_info=params_info
+            )
+
+        suggestion = Transformer.transform_list_to_dict(params_info, suggestion)
+        return suggestion
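Reviewer note: a small sketch of one suggestion step with two fabricated observations, using the same `params_info` layout as the config example above. With fewer than `min_lineage_rows` trials the tuner falls back to random sampling; here two rows are enough to fit the GP.

```python
# Sketch: one GP suggestion step over two made-up prior trials.
import numpy as np

from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner

params_info = {
    "learning_rate": {"bounds": [0.0001, 0.01], "type": "float", "decimal": 4},
    "batch_size": {"choice": [16, 32, 64], "type": "int"},
}

# Rows are trials, columns follow the order of params_info.
params = np.array([[0.001, 32], [0.005, 64]])
# GPBaseTuner maximizes the target, so pass higher-is-better values
# (Tuner._suggest negates targets whose goal is 'minimize').
target = np.array([0.71, 0.82])

suggestion = GPBaseTuner().suggest(params, target, params_info)
print(suggestion)  # e.g. {'learning_rate': 0.0047, 'batch_size': 64}
```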
diff --git a/mindinsight/optimizer/utils/param_handler.py b/mindinsight/optimizer/utils/param_handler.py
new file mode 100644
index 00000000..65736654
--- /dev/null
+++ b/mindinsight/optimizer/utils/param_handler.py
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utils for params."""
+import numpy as np
+from mindinsight.lineagemgr.model import LineageTable
+from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType
+from mindinsight.optimizer.common.log import logger
+
+
+def generate_param(param_info, n=1):
+    """Generate param."""
+    value = None
+    if HyperParamKey.BOUND.value in param_info:
+        bound = param_info[HyperParamKey.BOUND.value]
+        value = np.random.uniform(bound[0], bound[1], n)
+        if param_info[HyperParamKey.TYPE.value] == HyperParamType.INT.value:
+            value = value.astype(HyperParamType.INT.value)
+    if HyperParamKey.CHOICE.value in param_info:
+        indexes = np.random.randint(0, len(param_info[HyperParamKey.CHOICE.value]), n)
+        value = [param_info[HyperParamKey.CHOICE.value][index] for index in indexes]
+    if HyperParamKey.DECIMAL.value in param_info:
+        value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
+    return np.array(value)
+
+
+def generate_arrays(params_info: dict, n=1):
+    """Generate arrays."""
+    suggest_params = None
+    for _, param_info in params_info.items():
+        suggest_param = generate_param(param_info, n).reshape((-1, 1))
+        if suggest_params is None:
+            suggest_params = suggest_param
+        else:
+            suggest_params = np.hstack((suggest_params, suggest_param))
+    if n == 1:
+        return suggest_params[0]
+    return suggest_params
+
+
+def match_value_type(array, params_info: dict):
+    """Make array match params type."""
+    array_new = []
+    index = 0
+    for _, param_info in params_info.items():
+        param_type = param_info[HyperParamKey.TYPE.value]
+        value = array[index]
+        if HyperParamKey.BOUND.value in param_info:
+            bound = param_info[HyperParamKey.BOUND.value]
+            value = max(bound[0], array[index])
+            value = min(bound[1], value)
+        if HyperParamKey.CHOICE.value in param_info:
+            choices = param_info[HyperParamKey.CHOICE.value]
+            nearest_index = int(np.argmin(np.fabs(np.array(choices) - value)))
+            value = choices[nearest_index]
+        if param_type == HyperParamType.INT.value:
+            value = int(value)
+        if HyperParamKey.DECIMAL.value in param_info:
+            value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
+        array_new.append(value)
+        index += 1
+    return array_new
+
+
+def organize_params_target(lineage_table: LineageTable, params_info: dict, target_name):
+    """Organize params and target."""
+    empty_result = np.array([])
+    if lineage_table is None:
+        return empty_result, empty_result
+
+    param_keys = list(params_info.keys())
+
+    lineage_df = lineage_table.dataframe_data
+    try:
+        lineage_df = lineage_df[param_keys + [target_name]]
+        lineage_df = lineage_df.dropna(axis=0, how='any')
+        return lineage_df[param_keys], lineage_df[target_name]
+    except KeyError as exc:
+        logger.warning("Some keys do not exist in specified params or target. "
+                       "Params will be suggested randomly. Detail: %s.", str(exc))
+        return empty_result, empty_result
diff --git a/mindinsight/optimizer/utils/transformer.py b/mindinsight/optimizer/utils/transformer.py
new file mode 100644
index 00000000..f15f68b4
--- /dev/null
+++ b/mindinsight/optimizer/utils/transformer.py
@@ -0,0 +1,29 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Transformer."""
+from mindinsight.optimizer.utils.param_handler import match_value_type
+
+
+class Transformer:
+    """Transformer."""
+    @staticmethod
+    def transform_list_to_dict(params_info, suggest_list):
+        """Transform the suggested list from the tuner into a dict keyed by param name."""
+        suggest_list = match_value_type(suggest_list, params_info)
+        param_dict = {}
+        for index, param_name in enumerate(params_info):
+            param_dict.update({param_name: suggest_list[index]})
+
+        return param_dict
diff --git a/mindinsight/utils/constant.py b/mindinsight/utils/constant.py
index 49b34acd..480cde9f 100644
--- a/mindinsight/utils/constant.py
+++ b/mindinsight/utils/constant.py
@@ -96,3 +96,5 @@ class OptimizerErrors(Enum):
     """Enum definition for optimizer errors."""
     SAMPLES_NOT_ENOUGH = 1
     CORRELATION_NAN = 2
+    HYPER_CONFIG_ERROR = 3
+    OPTIMIZER_TERMINATE = 4
diff --git a/requirements.txt b/requirements.txt
index 92c75c51..b6dce730 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,9 @@ marshmallow>=2.19.2
 numpy>=1.17.0
 protobuf>=3.8.0
 psutil>=5.6.1
+pyyaml>=5.3
+scipy>=1.4.1
+scikit-learn>=0.23.1
 six>=1.12.0
 Werkzeug>=1.0.0
 pandas>=1.0.4
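Reviewer note: to close, a sketch of the `param_handler` round trip the GP tuner relies on — random candidate generation with `generate_arrays`, and `match_value_type` snapping a raw vector back to valid values (clipping to bounds, nearest choice, int casting). The `params_info` dict is made up, matching the earlier examples.

```python
# Sketch: the param_handler helpers in isolation (made-up params_info).
from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type

params_info = {
    "learning_rate": {"bounds": [0.0001, 0.01], "type": "float", "decimal": 4},
    "batch_size": {"choice": [16, 32, 64], "type": "int"},
}

candidates = generate_arrays(params_info, n=5)  # one row per sample, one column per param
print(candidates.shape)                         # (5, 2)

# Clip 0.02 to the 0.01 upper bound, snap 40 to the nearest choice (32).
print(match_value_type([0.02, 40], params_info))  # [0.01, 32]
```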