@@ -14,12 +14,11 @@
 # ============================================================================
 """Optimizer API module."""
 import json
-import pandas as pd
 from flask import Blueprint, jsonify, request
 
 from mindinsight.conf import settings
 from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
-from mindinsight.lineagemgr.model import get_flattened_lineage, LineageTable
+from mindinsight.lineagemgr.model import get_lineage_table
 from mindinsight.optimizer.common.enums import ReasonCode
 from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
 from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
@@ -44,8 +43,7 @@ def get_optimize_targets():
 
 
 def _get_optimize_targets(data_manager, search_condition=None):
     """Get optimize targets."""
-    flatten_lineage = get_flattened_lineage(data_manager, search_condition)
-    table = LineageTable(pd.DataFrame(flatten_lineage))
+    table = get_lineage_table(data_manager, search_condition)
     target_summaries = []
     for target in table.target_names:
@@ -179,8 +179,8 @@ class SearchModelConditionParameter(Schema):
             raise ValidationError("Given lineage type should be one of %s." % lineage_types)
 
     @pre_load
-    def check_comparision(self, data, **kwargs):
-        """Check comparision for all parameters in schema."""
+    def check_comparison(self, data, **kwargs):
+        """Check comparison for all parameters in schema."""
         for attr, condition in data.items():
             if attr in ["limit", "offset", "sorted_name", "sorted_type", 'lineage_type']:
                 continue
@@ -27,8 +27,8 @@ from mindinsight.optimizer.common.enums import ReasonCode
 from mindinsight.optimizer.utils.utils import is_simple_numpy_number
 from mindinsight.utils.exceptions import MindInsightException
 
-_METRIC_PREFIX = "[M]"
-_USER_DEFINED_PREFIX = "[U]"
+METRIC_PREFIX = "[M]"
+USER_DEFINED_PREFIX = "[U]"
 USER_DEFINED_INFO_LIMIT = 100
@@ -85,7 +85,7 @@ def get_flattened_lineage(data_manager, search_condition=None):
     for index, lineage in enumerate(lineages):
         flatten_dict['train_id'].append(lineage.get("summary_dir"))
         for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
-            if key.startswith(_USER_DEFINED_PREFIX) and key not in flatten_dict:
+            if key.startswith(USER_DEFINED_PREFIX) and key not in flatten_dict:
                 if user_count > USER_DEFINED_INFO_LIMIT:
                     log.warning("The user_defined_info has reached the limit %s. %r is ignored",
                                 USER_DEFINED_INFO_LIMIT, key)
@@ -105,10 +105,10 @@ def _flatten_lineage(lineage):
     for key, val in lineage.items():
         if key == 'metric':
             for k, v in val.items():
-                yield f'{_METRIC_PREFIX}{k}', v
+                yield f'{METRIC_PREFIX}{k}', v
         elif key == 'user_defined':
             for k, v in val.items():
-                yield f'{_USER_DEFINED_PREFIX}{k}', v
+                yield f'{USER_DEFINED_PREFIX}{k}', v
         else:
             yield key, val
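For context, a quick sketch of what `_flatten_lineage` yields for a single record, using the renamed public prefixes (the record's field names and values here are hypothetical):

```python
from mindinsight.lineagemgr.model import _flatten_lineage

# Hypothetical lineage record: one hyper param, one metric, one user-defined field.
lineage = {'learning_rate': 0.01, 'metric': {'acc': 0.92}, 'user_defined': {'seed': 1}}
print(dict(_flatten_lineage(lineage)))
# {'learning_rate': 0.01, '[M]acc': 0.92, '[U]seed': 1}
```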
@@ -144,7 +144,7 @@ class LineageTable:
 
         self._df = self._df.drop(columns=columns_to_drop)
         for name in columns_to_drop:
-            if not name.startswith(_USER_DEFINED_PREFIX):
+            if not name.startswith(USER_DEFINED_PREFIX):
                 continue
             self._drop_columns_info.append({
                 "name": name,
@@ -155,7 +155,7 @@ class LineageTable:
     @property
     def target_names(self):
         """Get names for optimize targets (e.g. loss, accuracy)."""
-        target_names = [name for name in self._df.columns if name.startswith(_METRIC_PREFIX)]
+        target_names = [name for name in self._df.columns if name.startswith(METRIC_PREFIX)]
         if self._LOSS_NAME in self._df.columns:
             target_names.append(self._LOSS_NAME)
         return target_names
@@ -167,7 +167,7 @@ class LineageTable:
         hyper_param_names = [
             name for name in self._df.columns
-            if not name.startswith(_METRIC_PREFIX) and name not in blocked_names]
+            if not name.startswith(METRIC_PREFIX) and name not in blocked_names]
 
         if self._LOSS_NAME in hyper_param_names:
             hyper_param_names.remove(self._LOSS_NAME)
@@ -184,7 +184,7 @@ class LineageTable:
     @property
     def user_defined_hyper_param_names(self):
         """Get user defined hyper param names."""
-        names = [name for name in self._df.columns if name.startswith(_USER_DEFINED_PREFIX)]
+        names = [name for name in self._df.columns if name.startswith(USER_DEFINED_PREFIX)]
         return names
 
     def get_column(self, name):
@@ -220,3 +220,11 @@ class LineageTable:
     def drop_column_info(self):
         """Get dropped columns info."""
         return self._drop_columns_info
+
+
+def get_lineage_table(data_manager, search_condition=None):
+    """Get lineage table from data_manager."""
+    lineage_table = get_flattened_lineage(data_manager, search_condition)
+    lineage_table = LineageTable(pd.DataFrame(lineage_table))
+    return lineage_table
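A minimal sketch of how the new helper is consumed (assuming `data_manager` is an already-loaded `DataManager`; the printed values are hypothetical):

```python
from mindinsight.lineagemgr.model import get_lineage_table

table = get_lineage_table(data_manager)
print(table.target_names)                    # e.g. ['[M]acc', 'loss']
print(table.user_defined_hyper_param_names)  # e.g. ['[U]seed']
```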
@@ -17,3 +17,5 @@ Optimizer.
 Optimizer provides optimization target distribution, parameter importance, etc.
 """
+
+from mindinsight.optimizer.hyper_config import HyperConfig
@@ -0,0 +1,16 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Common constants for optimizer."""
+HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
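The value stored under this environment variable is the JSON payload that `Tuner.optimize` writes before launching the training command, and that `HyperConfig` parses on the training side. A hypothetical payload:

```python
import json
import os

# Hypothetical values; 'params' and 'summary_dir' are the two required keys.
payload = {
    'params': {'learning_rate': 0.002, 'batch_size': 64},
    'summary_dir': '/tmp/summaries/train_1',
}
os.environ['MINDINSIGHT_HYPER_CONFIG'] = json.dumps(payload)
```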
@@ -30,3 +30,46 @@ class ReasonCode(BaseEnum):
     NOT_ALL_NUMBERS = 1
     SAMPLES_NOT_ENOUGH = 2
     CORRELATION_NAN = 3
+
+
+class AcquisitionFunctionEnum(BaseEnum):
+    """Enum for acquisition function method."""
+    # Upper confidence bound
+    UCB = 'ucb'
+    # Probability of improvement
+    PI = 'pi'
+    # Expected improvement
+    EI = 'ei'
+
+
+class TuneMethod(BaseEnum):
+    """Enum for tuning method."""
+    # Gaussian process regressor
+    GP = 'gp'
+
+
+class HyperParamKey(BaseEnum):
+    """Config keys for hyper parameters."""
+    BOUND = 'bounds'
+    CHOICE = 'choice'
+    DECIMAL = 'decimal'
+    TYPE = 'type'
+
+
+class HyperParamType(BaseEnum):
+    """Types for hyper parameters."""
+    INT = 'int'
+    FLOAT = 'float'
+
+
+class TargetKey(BaseEnum):
+    """Config keys for target."""
+    GROUP = 'group'
+    NAME = 'name'
+    GOAL = 'goal'
+
+
+class TargetGoal(BaseEnum):
+    """Goal for target."""
+    MAXIMUM = 'maximize'
+    MINIMUM = 'minimize'
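These enum values correspond to keys in the tuner's YAML config file. A hedged sketch of such a config (the layout is inferred from how `Tuner` reads it below; the schema is not formally validated yet, per the TODO in `_validate_config`):

```python
import yaml

# Hypothetical config exercising the keys Tuner reads: command, summary_base_dir,
# tuner.name (TuneMethod), target (TargetKey/TargetGoal) and parameters (HyperParamKey).
config_info = yaml.safe_load("""
command: python train.py
summary_base_dir: /tmp/summaries
tuner:
  name: gp
target:
  group: metric
  name: acc
  goal: maximize
parameters:
  learning_rate:
    bounds: [0.0001, 0.01]
    type: float
  batch_size:
    choice: [32, 64, 128]
    type: int
""")
```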
@@ -32,3 +32,19 @@ class CorrelationNanError(MindInsightException):
         super(CorrelationNanError, self).__init__(OptimizerErrors.CORRELATION_NAN,
                                                   error_msg,
                                                   http_code=400)
+
+
+class HyperConfigError(MindInsightException):
+    """Hyper config error."""
+    def __init__(self, error_msg="Hyper config is not correct."):
+        super(HyperConfigError, self).__init__(OptimizerErrors.HYPER_CONFIG_ERROR,
+                                               error_msg,
+                                               http_code=400)
+
+
+class OptimizerTerminateError(MindInsightException):
+    """Optimizer terminate error."""
+    def __init__(self, error_msg="Auto tuning has been terminated."):
+        super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE,
+                                                      error_msg,
+                                                      http_code=400)
@@ -0,0 +1,94 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Hyper config."""
+import json
+import os
+
+from attrdict import AttrDict
+
+from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
+from mindinsight.optimizer.common.exceptions import HyperConfigError
+
+_HYPER_CONFIG_LEN_LIMIT = 100000
+
+
+class HyperConfig:
+    """
+    Hyper config.
+
+    Init hyper config:
+    >>> hyper_config = HyperConfig()
+
+    Get suggested params:
+    >>> param_obj = hyper_config.params
+    >>> learning_rate = param_obj.learning_rate
+
+    Get summary dir:
+    >>> summary_dir = hyper_config.summary_dir
+
+    Record by SummaryCollector:
+    >>> summary_cb = SummaryCollector(summary_dir)
+    """
+    def __init__(self):
+        self._init_validate_hyper_config()
+
+    def _init_validate_hyper_config(self):
+        """Init and validate hyper config."""
+        hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
+        if hyper_config is None:
+            raise HyperConfigError("Hyper config is not in system environment.")
+        if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT:
+            raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of "
+                                   "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
+
+        try:
+            hyper_config = json.loads(hyper_config)
+        except TypeError as exc:
+            raise HyperConfigError("Hyper config type error. Detail: %s." % str(exc))
+        except Exception as exc:
+            raise HyperConfigError("Hyper config decode error. Detail: %s." % str(exc))
+
+        self._validate_hyper_config(hyper_config)
+        self._summary_dir = hyper_config.get('summary_dir')
+        self._param_obj = AttrDict(hyper_config.get('params'))
+
+    def _validate_hyper_config(self, hyper_config):
+        """Validate hyper config."""
+        for key in ['summary_dir', 'params']:
+            if key not in hyper_config:
+                raise HyperConfigError("%r must exist in hyper_config." % key)
+
+        # validate summary_dir
+        summary_dir = hyper_config.get('summary_dir')
+        if not isinstance(summary_dir, str):
+            raise HyperConfigError("The 'summary_dir' should be a string.")
+        hyper_config['summary_dir'] = os.path.realpath(summary_dir)
+
+        # validate params
+        params = hyper_config.get('params')
+        if not isinstance(params, dict):
+            raise HyperConfigError("'params' should be a dict.")
+        for key, value in params.items():
+            if not isinstance(value, (int, float)):
+                raise HyperConfigError("The value of %r should be an integer or a float." % key)
+
+    @property
+    def params(self):
+        """Get params."""
+        return self._param_obj
+
+    @property
+    def summary_dir(self):
+        """Get train summary dir path."""
+        return self._summary_dir
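On the training-script side, usage would look roughly like the class docstring above; a sketch assuming MindSpore's `SummaryCollector` and a hypothetical `learning_rate` param:

```python
from mindspore.train.callback import SummaryCollector

from mindinsight.optimizer.hyper_config import HyperConfig

config = HyperConfig()     # parses MINDINSIGHT_HYPER_CONFIG from the environment
params = config.params     # AttrDict of the suggested values
lr = params.learning_rate  # hypothetical param name
summary_cb = SummaryCollector(config.summary_dir)
# model.train(epochs, dataset, callbacks=[summary_cb])  # placeholder training call
```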
@@ -0,0 +1,163 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""General tuner."""
+import json
+import os
+import shlex
+import subprocess
+import uuid
+
+import numpy as np
+import yaml
+from marshmallow import ValidationError
+
+from mindinsight.datavisual.data_transform.data_manager import DataManager
+from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
+from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
+from mindinsight.lineagemgr.model import get_lineage_table, LineageTable, METRIC_PREFIX
+from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
+from mindinsight.optimizer.common.enums import TuneMethod, TargetKey, TargetGoal
+from mindinsight.optimizer.common.exceptions import OptimizerTerminateError
+from mindinsight.optimizer.common.log import logger
+from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner
+from mindinsight.optimizer.utils.param_handler import organize_params_target
+from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError
+
+_OK = 0
+
+
+class Tuner:
+    """
+    Tuner for auto tuning.
+
+    Args:
+        config_path (str): path of the config file, a YAML file containing settings for the tuner,
+            the optimize target and the hyper parameters.
+
+    Raises:
+        FileSystemPermissionError: the config file can not be opened due to permission restrictions.
+        UnknownError: any other error while loading the config file.
+    """
+    def __init__(self, config_path: str):
+        self._config_info = self._validate_config(config_path)
+        self._summary_base_dir = self._config_info.get('summary_base_dir')
+        self._data_manager = self._init_data_manager()
+        self._dir_prefix = 'train'
+
+    def _validate_config(self, config_path):
+        """Check config_path."""
+        config_path = self._normalize_path("config_path", config_path)
+        try:
+            with open(config_path, "r") as file:
+                config_info = yaml.safe_load(file)
+        except PermissionError as exc:
+            raise FileSystemPermissionError("Can not open config file. Detail: %s." % str(exc))
+        except Exception as exc:
+            raise UnknownError("Detail: %s." % str(exc))
+
+        # need to add validation for config_info: command, summary_base_dir, target and params.
+        config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir'))
+        self._make_summary_base_dir(config_info['summary_base_dir'])
+        return config_info
+
+    def _make_summary_base_dir(self, summary_base_dir):
+        """Check and make summary_base_dir."""
+        if not os.path.exists(summary_base_dir):
+            permissions = os.R_OK | os.W_OK | os.X_OK
+            os.umask(permissions << 3 | permissions)
+            mode = permissions << 6
+            try:
+                logger.info("The summary_base_dir is generated automatically, path is %s.", summary_base_dir)
+                os.makedirs(summary_base_dir, mode=mode, exist_ok=True)
+            except OSError as exc:
+                raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc))
+
+    def _init_data_manager(self):
+        """Initialize data_manager."""
+        data_manager = DataManager(summary_base_dir=self._summary_base_dir)
+        data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater())
+        return data_manager
+
+    def _normalize_path(self, param_name, path):
+        """Normalize config path."""
+        path = os.path.realpath(path)
+        try:
+            path = safe_normalize_path(
+                path, param_name, None, check_absolute_path=True
+            )
+        except ValidationError:
+            logger.error("The %r is invalid.", param_name)
+            raise ParamValueError("The %r is invalid." % param_name)
+        return path
+
+    def _update_from_lineage(self):
+        """Update lineage from lineagemgr."""
+        self._data_manager.start_load_data(reload_interval=0).join()
+
+        try:
+            lineage_table = get_lineage_table(self._data_manager)
+        except MindInsightException as err:
+            logger.info("Can not query lineage. Detail: %s", str(err))
+            lineage_table = None
+
+        self._lineage_table = lineage_table
+
+    def optimize(self, max_expr_times=1):
+        """Method for auto tuning."""
+        target_info = self._config_info.get('target')
+        params_info = self._config_info.get('parameters')
+        command = self._config_info.get('command')
+        tuner = self._config_info.get('tuner')
+        for _ in range(max_expr_times):
+            self._update_from_lineage()
+            suggestion = self._suggest(self._lineage_table, params_info, target_info, method=tuner.get("name"))
+
+            hyper_config = {
+                'params': suggestion,
+                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}')
+            }
+            os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config)
+
+            s = subprocess.Popen(shlex.split(command))
+            s.wait()
+            if s.returncode != _OK:
+                logger.error("An error occurred during execution, the auto tuning will be terminated.")
+                raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.")
+
+    def _get_tuner(self, tune_method=TuneMethod.GP.value):
+        """Get tuner."""
+        if tune_method.lower() not in TuneMethod.list_members():
+            raise ParamValueError("The 'tune_method' should be in %s." % TuneMethod.list_members())
+
+        # Only support gaussian process regressor currently.
+        return GPBaseTuner()
+
+    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, method):
+        """Get suggestions for targets."""
+        tuner = self._get_tuner(method)
+        target_name = target_info[TargetKey.NAME.value]
+        if TargetKey.GROUP.value in target_info and target_info[TargetKey.GROUP.value] == 'metric':
+            target_name = METRIC_PREFIX + target_name
+        param_matrix, target_matrix = organize_params_target(lineage_table, params_info, target_name)
+
+        if len(param_matrix) == 0:
+            suggestion = tuner.suggest([], [], params_info)
+        else:
+            target_column = np.array(target_matrix).reshape((-1, 1))
+            if target_info.get(TargetKey.GOAL.value) == TargetGoal.MINIMUM.value:
+                target_column = -target_column
+            suggestion = tuner.suggest(param_matrix, target_column, params_info)
+
+        return suggestion
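Putting it together, a usage sketch for the tuner itself (the module path and config file name are assumptions based on this PR):

```python
from mindinsight.optimizer.tuner import Tuner  # module path assumed

tuner = Tuner('./config.yaml')
tuner.optimize(max_expr_times=5)  # re-reads lineage and launches `command` five times
```

Note the GP tuner maximizes its target internally, so the target column is negated when the goal is `minimize`; a returncode other than `_OK` from the training subprocess terminates the loop via OptimizerTerminateError.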
@@ -0,0 +1,22 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Base tuner."""
+from abc import abstractmethod
+
+
+class BaseTuner:
+    @abstractmethod
+    def suggest(self, *args):
+        """Suggest method should be implemented."""
@@ -0,0 +1,178 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""GP Tuner."""
+import warnings
+
+import numpy as np
+from scipy.stats import norm
+from scipy.optimize import minimize
+from sklearn.gaussian_process import GaussianProcessRegressor
+from sklearn.gaussian_process.kernels import Matern
+
+from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey
+from mindinsight.optimizer.tuners.base_tuner import BaseTuner
+from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type
+from mindinsight.optimizer.utils.transformer import Transformer
+from mindinsight.utils.exceptions import ParamValueError
+
+
+class AcquisitionFunction:
+    """
+    It can be seen from the Gaussian process that a probability description of the objective
+    function can be obtained by sampling. Sampling usually involves two aspects:
+
+    - Explore: explore new spaces; this sampling helps to estimate a more accurate result.
+    - Exploit: sample near the existing results (usually near the existing maximum value),
+      hoping to find larger results.
+
+    The purpose of the acquisition function is to balance these two sampling processes.
+
+    Supported acquisition functions:
+
+    - Probability of improvement.
+    - Expected improvement.
+    - Upper confidence bound: the weighted sum of the posterior mean and the posterior
+      standard deviation. Formula: result = exploitation + βt * exploration, where βt is
+      an appropriate constant.
+
+    Args:
+        method (str): The method for the acquisition function, one of 'ucb', 'pi' and 'ei'.
+        beta (float): Trade-off parameter for the upper confidence bound function.
+        beta_decay (float): The decay for beta. Formula: beta = beta * beta_decay.
+        beta_decay_delay (int): If the counter is bigger than beta_decay_delay, the beta begins to decay.
+        xi (float): Trade-off parameter for expected improvement and probability of improvement.
+    """
+    def __init__(self, method: str, beta, xi, beta_decay=1, beta_decay_delay=0):
+        self._beta = beta
+        self._beta_decay = beta_decay
+        self._beta_decay_delay = beta_decay_delay
+        self._xi = xi
+        self._method = method.lower()
+        if self._method not in AcquisitionFunctionEnum.list_members():
+            raise ParamValueError(error_detail="The 'method' should be in %s."
+                                               % AcquisitionFunctionEnum.list_members())
+
+        self._counter = 0
+
+    def update(self):
+        """Update beta with decay."""
+        self._counter += 1
+        if self._counter > self._beta_decay_delay and self._beta_decay < 1:
+            self._beta *= self._beta_decay
+
+    def ac(self, x, gp, y_max):
+        """Acquisition function."""
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            mean, std = gp.predict(x, return_std=True)
+
+        if self._method == AcquisitionFunctionEnum.UCB.value:
+            # Upper confidence bound.
+            res = mean + self._beta * std
+        elif self._method == AcquisitionFunctionEnum.EI.value:
+            # Expected improvement.
+            u_f = (mean - y_max - self._xi)
+            z = u_f / std
+            res = u_f * norm.cdf(z) + std * norm.pdf(z)
+        else:
+            # Probability of improvement.
+            z = (mean - y_max - self._xi) / std
+            res = norm.cdf(z)
+        return res
+
+
+class GPBaseTuner(BaseTuner):
+    """
+    Tuner using gaussian process regressor.
+
+    Args:
+        method (str): The method for the acquisition function, one of 'ucb', 'pi' and 'ei'.
+            See AcquisitionFunction for details.
+        beta (float): β, trade-off parameter for the upper confidence bound function.
+        beta_decay (float): The decay for beta. Formula: beta = beta * beta_decay.
+        beta_decay_delay (int): If the counter is bigger than beta_decay_delay, the beta begins to decay.
+        xi (float): ξ, trade-off parameter for expected improvement and probability of improvement.
+        random_state (np.random.RandomState): If it is None, a new RandomState will be created.
+    """
+    def __init__(self,
+                 method=AcquisitionFunctionEnum.UCB.value,
+                 beta=2.576,
+                 beta_decay=1,
+                 beta_decay_delay=0,
+                 xi=0.0,
+                 random_state=None):
+        self._random_state = self._get_random_state(random_state)
+        self._utility_function = AcquisitionFunction(method=method,
+                                                     beta=beta,
+                                                     xi=xi,
+                                                     beta_decay=beta_decay,
+                                                     beta_decay_delay=beta_decay_delay)
+        self._gp = GaussianProcessRegressor(
+            kernel=Matern(nu=2.5),
+            alpha=1e-6,
+            normalize_y=True,
+            n_restarts_optimizer=5,
+            random_state=self._random_state
+        )
+
+    def _get_random_state(self, random_state=None):
+        """Get random state."""
+        if random_state is not None and not isinstance(random_state, (int, np.random.RandomState)):
+            raise ParamValueError("The 'random_state' should be None, integer or np.random.RandomState.")
+        if not isinstance(random_state, np.random.RandomState):
+            random_state = np.random.RandomState(random_state)
+        return random_state
+
+    def _acq_max(self, gp, y_max, bounds, params_info, n_warmup=10000, n_iter=10):
+        """Get max try calculated by acquisition function."""
+        x_tries = generate_arrays(params_info, n_warmup)
+        ys = self._utility_function.ac(x_tries, gp=gp, y_max=y_max)
+        x_max = x_tries[ys.argmax()]
+        max_acq = ys.max()
+
+        x_seeds = generate_arrays(params_info, n_iter)
+        for x_try in x_seeds:
+            res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max),
+                           x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B")
+            if not res.success:
+                continue
+            if max_acq is None or -res.fun[0] >= max_acq:
+                x_max = res.x
+                max_acq = -res.fun[0]
+
+        return np.clip(x_max, bounds[:, 0], bounds[:, 1])
+
+    def suggest(self, params, target, params_info: dict):
+        """Get suggest values."""
+        bounds = []
+        for param_info in params_info.values():
+            bound = param_info[HyperParamKey.BOUND.value] if HyperParamKey.BOUND.value in param_info \
+                else param_info[HyperParamKey.CHOICE.value]
+            bounds.append([min(bound), max(bound)])
+        bounds = np.array(bounds)
+
+        min_lineage_rows = 2
+        if not np.array(params).any() or params.shape[0] < min_lineage_rows:
+            suggestion = generate_arrays(params_info)
+        else:
+            self._gp.fit(params, target)
+            suggestion = self._acq_max(
+                gp=self._gp,
+                y_max=target.max(),
+                bounds=bounds,
+                params_info=params_info
+            )
+        suggestion = Transformer.transform_list_to_dict(params_info, suggestion)
+        return suggestion
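To make the three branches of `ac` concrete, a standalone computation with made-up posterior values:

```python
from scipy.stats import norm

# Hypothetical posterior at one candidate: mean 0.80, std 0.10; best observed y_max 0.75.
mean, std, y_max = 0.80, 0.10, 0.75
beta, xi = 2.576, 0.0

ucb = mean + beta * std                                     # 'ucb' -> 1.0576
z = (mean - y_max - xi) / std                               # 0.5
ei = (mean - y_max - xi) * norm.cdf(z) + std * norm.pdf(z)  # 'ei'  -> ~0.0698
pi = norm.cdf(z)                                            # 'pi'  -> ~0.6915
```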
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utils for params."""
+import numpy as np
+
+from mindinsight.lineagemgr.model import LineageTable
+from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType
+from mindinsight.optimizer.common.log import logger
+
+
+def generate_param(param_info, n=1):
+    """Generate param."""
+    value = None
+    if HyperParamKey.BOUND.value in param_info:
+        bound = param_info[HyperParamKey.BOUND.value]
+        value = np.random.uniform(bound[0], bound[1], n)
+        if param_info[HyperParamKey.TYPE.value] == HyperParamType.INT.value:
+            value = value.astype(HyperParamType.INT.value)
+    if HyperParamKey.CHOICE.value in param_info:
+        indexes = np.random.randint(0, len(param_info[HyperParamKey.CHOICE.value]), n)
+        value = [param_info[HyperParamKey.CHOICE.value][index] for index in indexes]
+    if HyperParamKey.DECIMAL.value in param_info:
+        value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
+    return np.array(value)
+
+
+def generate_arrays(params_info: dict, n=1):
+    """Generate arrays."""
+    suggest_params = None
+    for _, param_info in params_info.items():
+        suggest_param = generate_param(param_info, n).reshape((-1, 1))
+        if suggest_params is None:
+            suggest_params = suggest_param
+        else:
+            suggest_params = np.hstack((suggest_params, suggest_param))
+    if n == 1:
+        return suggest_params[0]
+    return suggest_params
+
+
+def match_value_type(array, params_info: dict):
+    """Make array match params type."""
+    array_new = []
+    index = 0
+    for _, param_info in params_info.items():
+        param_type = param_info[HyperParamKey.TYPE.value]
+        value = array[index]
+        if HyperParamKey.BOUND.value in param_info:
+            bound = param_info[HyperParamKey.BOUND.value]
+            value = max(bound[0], array[index])
+            value = min(bound[1], value)
+        if HyperParamKey.CHOICE.value in param_info:
+            choices = param_info[HyperParamKey.CHOICE.value]
+            nearest_index = int(np.argmin(np.fabs(np.array(choices) - value)))
+            value = choices[nearest_index]
+        if param_type == HyperParamType.INT.value:
+            value = int(value)
+        if HyperParamKey.DECIMAL.value in param_info:
+            value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
+        array_new.append(value)
+        index += 1
+    return array_new
+
+
+def organize_params_target(lineage_table: LineageTable, params_info: dict, target_name):
+    """Organize params and target."""
+    empty_result = np.array([])
+    if lineage_table is None:
+        return empty_result, empty_result
+
+    param_keys = list(params_info.keys())
+    lineage_df = lineage_table.dataframe_data
+    try:
+        lineage_df = lineage_df[param_keys + [target_name]]
+        lineage_df = lineage_df.dropna(axis=0, how='any')
+        return lineage_df[param_keys], lineage_df[target_name]
+    except KeyError as exc:
+        logger.warning("Some keys do not exist in specified params or target. It will suggest params randomly. "
+                       "Detail: %s.", str(exc))
+        return empty_result, empty_result
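A small sketch of the `params_info` dict these helpers expect (keys follow `HyperParamKey`; the param names are hypothetical):

```python
from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type

params_info = {
    'learning_rate': {'bounds': [0.0001, 0.01], 'type': 'float', 'decimal': 5},
    'batch_size': {'choice': [32, 64, 128], 'type': 'int'},
}
samples = generate_arrays(params_info, n=5)  # ndarray: one row per sample, one column per param
one = generate_arrays(params_info)           # flat array with a single value per param
print(match_value_type(one, params_info))    # values clipped and rounded to each param's spec
```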
@@ -0,0 +1,29 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Transformer."""
+from mindinsight.optimizer.utils.param_handler import match_value_type
+
+
+class Transformer:
+    """Transformer."""
+
+    @staticmethod
+    def transform_list_to_dict(params_info, suggest_list):
+        """Transform suggest list from tuner into a dict keyed by param name."""
+        suggest_list = match_value_type(suggest_list, params_info)
+        param_dict = {}
+        for index, param_name in enumerate(params_info):
+            param_dict.update({param_name: suggest_list[index]})
+        return param_dict
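For instance, with the same hypothetical `params_info` as in the sketch above, a raw suggestion from the tuner is snapped onto each param's spec:

```python
from mindinsight.optimizer.utils.transformer import Transformer

params_info = {
    'learning_rate': {'bounds': [0.0001, 0.01], 'type': 'float', 'decimal': 5},
    'batch_size': {'choice': [32, 64, 128], 'type': 'int'},
}
print(Transformer.transform_list_to_dict(params_info, [0.00337, 70.2]))
# {'learning_rate': 0.00337, 'batch_size': 64}
```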
@@ -96,3 +96,5 @@ class OptimizerErrors(Enum):
     """Enum definition for optimizer errors."""
     SAMPLES_NOT_ENOUGH = 1
     CORRELATION_NAN = 2
+    HYPER_CONFIG_ERROR = 3
+    OPTIMIZER_TERMINATE = 4
@@ -10,6 +10,9 @@ marshmallow>=2.19.2
 numpy>=1.17.0
 protobuf>=3.8.0
 psutil>=5.6.1
+pyyaml>=5.3
+scipy>=1.4.1
+scikit-learn>=0.23.1
 six>=1.12.0
 Werkzeug>=1.0.0
 pandas>=1.0.4