
add auto-tune framework, add validation for hyper_config

tags/v1.1.0
luopengting 5 years ago
parent commit: fb4c150e56
15 changed files with 681 additions and 15 deletions
  1. mindinsight/backend/optimizer/optimizer_api.py (+2 -4)
  2. mindinsight/lineagemgr/common/validator/model_parameter.py (+2 -2)
  3. mindinsight/lineagemgr/model.py (+17 -9)
  4. mindinsight/optimizer/__init__.py (+2 -0)
  5. mindinsight/optimizer/common/constants.py (+16 -0)
  6. mindinsight/optimizer/common/enums.py (+43 -0)
  7. mindinsight/optimizer/common/exceptions.py (+16 -0)
  8. mindinsight/optimizer/hyper_config.py (+94 -0)
  9. mindinsight/optimizer/tuner.py (+163 -0)
  10. mindinsight/optimizer/tuners/base_tuner.py (+22 -0)
  11. mindinsight/optimizer/tuners/gp_tuner.py (+178 -0)
  12. mindinsight/optimizer/utils/param_handler.py (+92 -0)
  13. mindinsight/optimizer/utils/transformer.py (+29 -0)
  14. mindinsight/utils/constant.py (+2 -0)
  15. requirements.txt (+3 -0)

mindinsight/backend/optimizer/optimizer_api.py (+2 -4)

@@ -14,12 +14,11 @@
 # ============================================================================
 """Optimizer API module."""
 import json
-import pandas as pd
 from flask import Blueprint, jsonify, request

 from mindinsight.conf import settings
 from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
-from mindinsight.lineagemgr.model import get_flattened_lineage, LineageTable
+from mindinsight.lineagemgr.model import get_lineage_table
 from mindinsight.optimizer.common.enums import ReasonCode
 from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
 from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
@@ -44,8 +43,7 @@ def get_optimize_targets():

 def _get_optimize_targets(data_manager, search_condition=None):
     """Get optimize targets."""
-    flatten_lineage = get_flattened_lineage(data_manager, search_condition)
-    table = LineageTable(pd.DataFrame(flatten_lineage))
+    table = get_lineage_table(data_manager, search_condition)

     target_summaries = []
     for target in table.target_names:


mindinsight/lineagemgr/common/validator/model_parameter.py (+2 -2)

@@ -179,8 +179,8 @@ class SearchModelConditionParameter(Schema):
         raise ValidationError("Given lineage type should be one of %s." % lineage_types)

     @pre_load
-    def check_comparision(self, data, **kwargs):
-        """Check comparision for all parameters in schema."""
+    def check_comparison(self, data, **kwargs):
+        """Check comparison for all parameters in schema."""
         for attr, condition in data.items():
             if attr in ["limit", "offset", "sorted_name", "sorted_type", 'lineage_type']:
                 continue


mindinsight/lineagemgr/model.py (+17 -9)

@@ -27,8 +27,8 @@ from mindinsight.optimizer.common.enums import ReasonCode
 from mindinsight.optimizer.utils.utils import is_simple_numpy_number
 from mindinsight.utils.exceptions import MindInsightException

-_METRIC_PREFIX = "[M]"
-_USER_DEFINED_PREFIX = "[U]"
+METRIC_PREFIX = "[M]"
+USER_DEFINED_PREFIX = "[U]"

 USER_DEFINED_INFO_LIMIT = 100

@@ -85,7 +85,7 @@ def get_flattened_lineage(data_manager, search_condition=None):
     for index, lineage in enumerate(lineages):
         flatten_dict['train_id'].append(lineage.get("summary_dir"))
         for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
-            if key.startswith(_USER_DEFINED_PREFIX) and key not in flatten_dict:
+            if key.startswith(USER_DEFINED_PREFIX) and key not in flatten_dict:
                 if user_count > USER_DEFINED_INFO_LIMIT:
                     log.warning("The user_defined_info has reached the limit %s. %r is ignored",
                                 USER_DEFINED_INFO_LIMIT, key)
@@ -105,10 +105,10 @@ def _flatten_lineage(lineage):
     for key, val in lineage.items():
         if key == 'metric':
             for k, v in val.items():
-                yield f'{_METRIC_PREFIX}{k}', v
+                yield f'{METRIC_PREFIX}{k}', v
         elif key == 'user_defined':
             for k, v in val.items():
-                yield f'{_USER_DEFINED_PREFIX}{k}', v
+                yield f'{USER_DEFINED_PREFIX}{k}', v
         else:
             yield key, val

@@ -144,7 +144,7 @@ class LineageTable:
         self._df = self._df.drop(columns=columns_to_drop)

         for name in columns_to_drop:
-            if not name.startswith(_USER_DEFINED_PREFIX):
+            if not name.startswith(USER_DEFINED_PREFIX):
                 continue
             self._drop_columns_info.append({
                 "name": name,
@@ -155,7 +155,7 @@ class LineageTable:
     @property
     def target_names(self):
         """Get names for optimize targets (e.g. loss, accuracy)."""
-        target_names = [name for name in self._df.columns if name.startswith(_METRIC_PREFIX)]
+        target_names = [name for name in self._df.columns if name.startswith(METRIC_PREFIX)]
         if self._LOSS_NAME in self._df.columns:
             target_names.append(self._LOSS_NAME)
         return target_names
@@ -167,7 +167,7 @@ class LineageTable:

         hyper_param_names = [
             name for name in self._df.columns
-            if not name.startswith(_METRIC_PREFIX) and name not in blocked_names]
+            if not name.startswith(METRIC_PREFIX) and name not in blocked_names]

         if self._LOSS_NAME in hyper_param_names:
             hyper_param_names.remove(self._LOSS_NAME)
@@ -184,7 +184,7 @@ class LineageTable:
     @property
     def user_defined_hyper_param_names(self):
         """Get user defined hyper param names."""
-        names = [name for name in self._df.columns if name.startswith(_USER_DEFINED_PREFIX)]
+        names = [name for name in self._df.columns if name.startswith(USER_DEFINED_PREFIX)]
         return names

     def get_column(self, name):
@@ -220,3 +220,11 @@ class LineageTable:
     def drop_column_info(self):
         """Get dropped columns info."""
         return self._drop_columns_info
+
+
+def get_lineage_table(data_manager, search_condition=None):
+    """Get lineage table from data_manager."""
+    lineage_table = get_flattened_lineage(data_manager, search_condition)
+    lineage_table = LineageTable(pd.DataFrame(lineage_table))
+
+    return lineage_table
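For reference, get_lineage_table is the new public entry point used by both the optimizer API and the tuner. A minimal usage sketch, assuming a summary base directory that already contains lineage data (the path is hypothetical); the wiring mirrors Tuner._init_data_manager and _update_from_lineage in tuner.py below:

from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
from mindinsight.lineagemgr.model import get_lineage_table

# Hypothetical summary directory containing training lineage.
data_manager = DataManager(summary_base_dir="/tmp/summary_base_dir")
data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater())
data_manager.start_load_data(reload_interval=0).join()

table = get_lineage_table(data_manager)
print(table.target_names)                    # metric columns ('[M]' prefix) plus 'loss'
print(table.user_defined_hyper_param_names)  # user-defined columns ('[U]' prefix)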

mindinsight/optimizer/__init__.py (+2 -0)

@@ -17,3 +17,5 @@ Optimizer.

 Optimizer provides optimization target distribution, parameter importance, etc.
 """
+
+from mindinsight.optimizer.hyper_config import HyperConfig

mindinsight/optimizer/common/constants.py (+16 -0)

@@ -0,0 +1,16 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common constants for optimizer."""
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"

mindinsight/optimizer/common/enums.py (+43 -0)

@@ -30,3 +30,46 @@ class ReasonCode(BaseEnum):
     NOT_ALL_NUMBERS = 1
     SAMPLES_NOT_ENOUGH = 2
     CORRELATION_NAN = 3
+
+
+class AcquisitionFunctionEnum(BaseEnum):
+    """Enum for acquisition function methods."""
+    # Upper confidence bound
+    UCB = 'ucb'
+    # Probability of improvement
+    PI = 'pi'
+    # Expected improvement
+    EI = 'ei'
+
+
+class TuneMethod(BaseEnum):
+    """Enum for tuning methods."""
+    # Gaussian process regressor
+    GP = 'gp'
+
+
+class HyperParamKey(BaseEnum):
+    """Config keys for hyper parameters."""
+    BOUND = 'bounds'
+    CHOICE = 'choice'
+    DECIMAL = 'decimal'
+    TYPE = 'type'
+
+
+class HyperParamType(BaseEnum):
+    """Value types for hyper parameters."""
+    INT = 'int'
+    FLOAT = 'float'
+
+
+class TargetKey(BaseEnum):
+    """Config keys for the optimize target."""
+    GROUP = 'group'
+    NAME = 'name'
+    GOAL = 'goal'
+
+
+class TargetGoal(BaseEnum):
+    """Goal for the optimize target."""
+    MAXIMUM = 'maximize'
+    MINIMUM = 'minimize'
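HyperParamKey and HyperParamType together describe one tuning parameter. A sketch of the resulting params_info mapping, as consumed by param_handler.py and gp_tuner.py below; the parameter names and values are illustrative, not part of this commit:

# A hypothetical params_info dict: a bounded float parameter and a discrete choice parameter.
params_info = {
    'learning_rate': {'bounds': [0.0001, 0.1], 'type': 'float', 'decimal': 4},
    'batch_size': {'choice': [16, 32, 64, 128], 'type': 'int'},
}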

mindinsight/optimizer/common/exceptions.py (+16 -0)

@@ -32,3 +32,19 @@ class CorrelationNanError(MindInsightException):
         super(CorrelationNanError, self).__init__(OptimizerErrors.CORRELATION_NAN,
                                                   error_msg,
                                                   http_code=400)
+
+
+class HyperConfigError(MindInsightException):
+    """Hyper config error."""
+    def __init__(self, error_msg="Hyper config is not correct."):
+        super(HyperConfigError, self).__init__(OptimizerErrors.HYPER_CONFIG_ERROR,
+                                               error_msg,
+                                               http_code=400)
+
+
+class OptimizerTerminateError(MindInsightException):
+    """Optimizer terminate error."""
+    def __init__(self, error_msg="Auto tuning has been terminated."):
+        super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE,
+                                                      error_msg,
+                                                      http_code=400)

mindinsight/optimizer/hyper_config.py (+94 -0)

@@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hyper config."""
import json
import os
from attrdict import AttrDict

from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
from mindinsight.optimizer.common.exceptions import HyperConfigError

_HYPER_CONFIG_LEN_LIMIT = 100000


class HyperConfig:
    """
    Hyper config.

    Init hyper config:
    >>> hyper_config = HyperConfig()

    Get suggested params:
    >>> params = hyper_config.params
    >>> learning_rate = params.learning_rate

    Get summary dir:
    >>> summary_dir = hyper_config.summary_dir

    Record by SummaryCollector:
    >>> summary_cb = SummaryCollector(summary_dir)
    """
    def __init__(self):
        self._init_validate_hyper_config()

    def _init_validate_hyper_config(self):
        """Init and validate hyper config."""
        hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
        if hyper_config is None:
            raise HyperConfigError("Hyper config is not in system environment.")
        if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT:
            raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of "
                                   "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))

        try:
            hyper_config = json.loads(hyper_config)
        except TypeError as exc:
            raise HyperConfigError("Hyper config type error. Detail: %s." % str(exc))
        except Exception as exc:
            raise HyperConfigError("Hyper config decode error. Detail: %s." % str(exc))

        self._validate_hyper_config(hyper_config)
        self._summary_dir = hyper_config.get('summary_dir')
        self._param_obj = AttrDict(hyper_config.get('params'))

    def _validate_hyper_config(self, hyper_config):
        """Validate hyper config."""
        for key in ['summary_dir', 'params']:
            if key not in hyper_config:
                raise HyperConfigError("%r must exist in hyper_config." % key)

        # validate summary_dir
        summary_dir = hyper_config.get('summary_dir')
        if not isinstance(summary_dir, str):
            raise HyperConfigError("The 'summary_dir' should be a string.")
        hyper_config['summary_dir'] = os.path.realpath(summary_dir)

        # validate params
        params = hyper_config.get('params')
        if not isinstance(params, dict):
            raise HyperConfigError("'params' is not a dict.")
        for key, value in params.items():
            if not isinstance(value, (int, float)):
                raise HyperConfigError("The value of %r is not an integer or float." % key)

    @property
    def params(self):
        """Get params."""
        return self._param_obj

    @property
    def summary_dir(self):
        """Get train summary dir path."""
        return self._summary_dir
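HyperConfig reads its input from the MINDINSIGHT_HYPER_CONFIG environment variable, which Tuner.optimize (see tuner.py below) fills with a JSON document holding 'summary_dir' and 'params'. A minimal sketch of the consumer side; the payload values are assumptions for illustration:

import json
import os

from mindinsight.optimizer import HyperConfig

# Hypothetical payload; in a real run Tuner.optimize() sets this variable
# before launching the training command.
os.environ['MINDINSIGHT_HYPER_CONFIG'] = json.dumps({
    'summary_dir': '/tmp/summary_base_dir/train_1',
    'params': {'learning_rate': 0.01, 'batch_size': 64},
})

config = HyperConfig()
learning_rate = config.params.learning_rate  # AttrDict allows attribute access
summary_dir = config.summary_dir             # normalized via os.path.realpath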

mindinsight/optimizer/tuner.py (+163 -0)

@@ -0,0 +1,163 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""General tuner."""
import json
import os
import shlex
import subprocess
import uuid
import yaml

from marshmallow import ValidationError

from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
from mindinsight.lineagemgr.model import get_lineage_table, LineageTable, METRIC_PREFIX
from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
from mindinsight.optimizer.common.enums import TuneMethod, TargetKey, TargetGoal
from mindinsight.optimizer.common.exceptions import OptimizerTerminateError
from mindinsight.optimizer.common.log import logger
from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner
from mindinsight.optimizer.utils.param_handler import organize_params_target
from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError

_OK = 0


class Tuner:
    """
    Tuner for auto tuning.

    Args:
        config_path (str): Config path, a YAML file containing settings for the tuner,
            the optimize target, the tuning parameters and the training command.

    Raises:
        FileSystemPermissionError: Can not open the config file due to lack of permission.
        UnknownError: Any other error while reading the config file.
    """
    def __init__(self, config_path: str):
        self._config_info = self._validate_config(config_path)
        self._summary_base_dir = self._config_info.get('summary_base_dir')
        self._data_manager = self._init_data_manager()
        self._dir_prefix = 'train'

    def _validate_config(self, config_path):
        """Check config_path."""
        config_path = self._normalize_path("config_path", config_path)
        try:
            with open(config_path, "r") as file:
                config_info = yaml.safe_load(file)
        except PermissionError as exc:
            raise FileSystemPermissionError("Can not open config file. Detail: %s." % str(exc))
        except Exception as exc:
            raise UnknownError("Detail: %s." % str(exc))

        # need to add validation for config_info: command, summary_base_dir, target and params.
        config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir'))
        self._make_summary_base_dir(config_info['summary_base_dir'])
        return config_info

    def _make_summary_base_dir(self, summary_base_dir):
        """Check and make summary_base_dir."""
        if not os.path.exists(summary_base_dir):
            permissions = os.R_OK | os.W_OK | os.X_OK
            os.umask(permissions << 3 | permissions)
            mode = permissions << 6
            try:
                logger.info("The summary_base_dir is generated automatically, path is %s.", summary_base_dir)
                os.makedirs(summary_base_dir, mode=mode, exist_ok=True)
            except OSError as exc:
                raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc))

    def _init_data_manager(self):
        """Initialize data_manager."""
        data_manager = DataManager(summary_base_dir=self._summary_base_dir)
        data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater())

        return data_manager

    def _normalize_path(self, param_name, path):
        """Normalize config path."""
        path = os.path.realpath(path)
        try:
            path = safe_normalize_path(
                path, param_name, None, check_absolute_path=True
            )
        except ValidationError:
            logger.error("The %r is invalid.", param_name)
            raise ParamValueError("The %r is invalid." % param_name)

        return path

    def _update_from_lineage(self):
        """Update lineage from lineagemgr."""
        self._data_manager.start_load_data(reload_interval=0).join()

        try:
            lineage_table = get_lineage_table(self._data_manager)
        except MindInsightException as err:
            logger.info("Can not query lineage. Detail: %s", str(err))
            lineage_table = None

        self._lineage_table = lineage_table

    def optimize(self, max_expr_times=1):
        """Method for auto tuning."""
        target_info = self._config_info.get('target')
        params_info = self._config_info.get('parameters')
        command = self._config_info.get('command')
        tuner = self._config_info.get('tuner')
        for _ in range(max_expr_times):
            self._update_from_lineage()
            suggestion = self._suggest(self._lineage_table, params_info, target_info, method=tuner.get("name"))

            hyper_config = {
                'params': suggestion,
                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}')
            }
            os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config)
            process = subprocess.Popen(shlex.split(command))
            process.wait()
            if process.returncode != _OK:
                logger.error("An error occurred during execution, the auto tuning will be terminated.")
                raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.")

    def _get_tuner(self, tune_method=TuneMethod.GP.value):
        """Get tuner."""
        if tune_method.lower() not in TuneMethod.list_members():
            raise ParamValueError("'tune_method' should be in %s." % TuneMethod.list_members())

        # Only gaussian process regressor is supported currently.
        return GPBaseTuner()

    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, method):
        """Get suggestions for targets."""
        tuner = self._get_tuner(method)
        target_name = target_info[TargetKey.NAME.value]
        if TargetKey.GROUP.value in target_info and target_info[TargetKey.GROUP.value] == 'metric':
            target_name = METRIC_PREFIX + target_name
        param_matrix, target_matrix = organize_params_target(lineage_table, params_info, target_name)

        if param_matrix.empty:
            # No usable history yet: the tuner will suggest params randomly.
            suggestion = tuner.suggest([], [], params_info)
        else:
            target_column = target_matrix[target_name].values.reshape((-1, 1))
            if target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value:
                target_column = -target_column

            suggestion = tuner.suggest(param_matrix, target_column, params_info)

        return suggestion
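The config file is loaded with yaml.safe_load, and full validation of its contents is still an open TODO in _validate_config. A sketch of a config that matches the keys this class reads (summary_base_dir, command, tuner.name, target.group/name/goal, parameters); the concrete values are assumptions, not part of this commit:

import yaml

# Hypothetical config text; only the key names are taken from the code above.
config_text = '''
summary_base_dir: /tmp/summary_base_dir
command: python train.py
tuner:
  name: gp
target:
  group: metric
  name: accuracy
  goal: maximize
parameters:
  learning_rate:
    bounds: [0.0001, 0.1]
    type: float
  batch_size:
    choice: [16, 32, 64, 128]
    type: int
'''
config_info = yaml.safe_load(config_text)
assert config_info['tuner']['name'] == 'gp'

With such a file saved to disk, Tuner(config_path).optimize(max_expr_times=10) would repeatedly query lineage, compute a suggestion, export it through MINDINSIGHT_HYPER_CONFIG and re-run the training command.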

mindinsight/optimizer/tuners/base_tuner.py (+22 -0)

@@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base tuner."""
from abc import abstractmethod


class BaseTuner:
    """Base class for tuners."""

    @abstractmethod
    def suggest(self, *args):
        """Suggest method, to be implemented by subclasses."""

mindinsight/optimizer/tuners/gp_tuner.py (+178 -0)

@@ -0,0 +1,178 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GP Tuner."""
import warnings
import numpy as np

from scipy.stats import norm
from scipy.optimize import minimize
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern

from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey
from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type
from mindinsight.optimizer.utils.transformer import Transformer
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.optimizer.tuners.base_tuner import BaseTuner


class AcquisitionFunction:
    """
    A Gaussian process gives a probabilistic description of the objective function,
    which can be refined by sampling. Sampling usually involves two aspects:

    - Explore: explore new spaces; this sampling helps to estimate more accurate results;
    - Exploit: sample near the existing results (usually near the existing maximum value),
      hoping to find larger results.

    The purpose of the acquisition function is to balance these two sampling processes.
    Supported acquisition functions:

    - Probability of improvement.
    - Expected improvement.
    - Upper confidence bound: the weighted sum of the posterior mean and the posterior
      standard deviation. Formula: result = exploitation + βt * exploration, where βt is
      an appropriate constant.

    Args:
        method (str): The method for the acquisition function, one of 'ucb', 'pi' and 'ei'.
        beta (float): Trade-off parameter for the upper confidence bound function.
        xi (float): Trade-off parameter for expected improvement and probability of improvement.
        beta_decay (float): The decay for beta. Formula: beta = beta * beta_decay.
        beta_decay_delay (int): The beta begins to decay once the counter exceeds beta_decay_delay.
    """
    def __init__(self, method: str, beta, xi, beta_decay=1, beta_decay_delay=0):
        self._beta = beta
        self._beta_decay = beta_decay
        self._beta_decay_delay = beta_decay_delay
        self._xi = xi
        self._method = method.lower()
        if self._method not in AcquisitionFunctionEnum.list_members():
            raise ParamValueError(error_detail="The 'method' should be in %s." % AcquisitionFunctionEnum.list_members())

        self._counter = 0

    def update(self):
        """Update the counter and decay beta once the delay has passed."""
        self._counter += 1

        if self._counter > self._beta_decay_delay and self._beta_decay < 1:
            self._beta *= self._beta_decay

    def ac(self, x, gp, y_max):
        """Calculate the acquisition function value at x."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mean, std = gp.predict(x, return_std=True)

        if self._method == AcquisitionFunctionEnum.UCB.value:
            # Upper confidence bound.
            res = mean + self._beta * std
        elif self._method == AcquisitionFunctionEnum.EI.value:
            # Expected improvement.
            u_f = mean - y_max - self._xi
            z = u_f / std
            res = u_f * norm.cdf(z) + std * norm.pdf(z)
        else:
            # Probability of improvement.
            z = (mean - y_max - self._xi) / std
            res = norm.cdf(z)

        return res


class GPBaseTuner(BaseTuner):
    """
    Tuner using gaussian process regressor.

    Args:
        method (str): The method for the acquisition function, one of 'ucb', 'pi' and 'ei'.
            Detailed at AcquisitionFunction.
        beta (float): β, trade-off parameter for the upper confidence bound function.
        beta_decay (float): The decay for beta. Formula: beta = beta * beta_decay.
        beta_decay_delay (int): The beta begins to decay once the counter exceeds beta_decay_delay.
        xi (float): ξ, trade-off parameter for expected improvement and probability of improvement.
        random_state (np.random.RandomState): If it is None, a new RandomState will be created.
    """
    def __init__(self,
                 method=AcquisitionFunctionEnum.UCB.value,
                 beta=2.576,
                 beta_decay=1,
                 beta_decay_delay=0,
                 xi=0.0,
                 random_state=None):
        self._random_state = self._get_random_state(random_state)
        self._utility_function = AcquisitionFunction(method=method,
                                                     beta=beta,
                                                     xi=xi,
                                                     beta_decay=beta_decay,
                                                     beta_decay_delay=beta_decay_delay)
        self._gp = GaussianProcessRegressor(
            kernel=Matern(nu=2.5),
            alpha=1e-6,
            normalize_y=True,
            n_restarts_optimizer=5,
            random_state=self._random_state
        )

    def _get_random_state(self, random_state=None):
        """Get random state."""
        if random_state is not None and not isinstance(random_state, (int, np.random.RandomState)):
            raise ParamValueError("The 'random_state' should be None, an integer or np.random.RandomState.")
        if not isinstance(random_state, np.random.RandomState):
            random_state = np.random.RandomState(random_state)
        return random_state

    def _acq_max(self, gp, y_max, bounds, params_info, n_warmup=10000, n_iter=10):
        """Find the params that maximize the acquisition function."""
        # Warm up with random candidate points.
        x_tries = generate_arrays(params_info, n_warmup)
        ys = self._utility_function.ac(x_tries, gp=gp, y_max=y_max)
        x_max = x_tries[ys.argmax()]
        max_acq = ys.max()

        # Refine with L-BFGS-B from a few random seeds.
        x_seeds = generate_arrays(params_info, n_iter)
        for x_try in x_seeds:
            res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max),
                           x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B")

            if not res.success:
                continue

            # Keep the optimizer result if it beats the current maximum.
            if max_acq is None or -res.fun[0] >= max_acq:
                x_max = res.x
                max_acq = -res.fun[0]

        return np.clip(x_max, bounds[:, 0], bounds[:, 1])

    def suggest(self, params, target, params_info: dict):
        """Get suggested values."""
        bounds = []
        for param_info in params_info.values():
            bound = param_info[HyperParamKey.BOUND.value] if HyperParamKey.BOUND.value in param_info \
                else param_info[HyperParamKey.CHOICE.value]
            bounds.append([min(bound), max(bound)])
        bounds = np.array(bounds)

        min_lineage_rows = 2
        if not np.array(params).any() or params.shape[0] < min_lineage_rows:
            # Not enough history: fall back to random suggestions.
            suggestion = generate_arrays(params_info)
        else:
            self._gp.fit(params, target)
            suggestion = self._acq_max(
                gp=self._gp,
                y_max=target.max(),
                bounds=bounds,
                params_info=params_info
            )

        suggestion = Transformer.transform_list_to_dict(params_info, suggestion)
        return suggestion
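As a quick numeric illustration of the three acquisition functions implemented in AcquisitionFunction.ac; the posterior mean/std and y_max values are made up:

from scipy.stats import norm

mean, std, y_max = 0.80, 0.10, 0.85  # assumed posterior statistics at a candidate point
beta, xi = 2.576, 0.0                # defaults used by GPBaseTuner

ucb = mean + beta * std                                     # -> about 1.058
z = (mean - y_max - xi) / std
ei = (mean - y_max - xi) * norm.cdf(z) + std * norm.pdf(z)  # -> about 0.020
pi = norm.cdf(z)                                            # -> about 0.309
print(ucb, ei, pi)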

mindinsight/optimizer/utils/param_handler.py (+92 -0)

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for params."""
import numpy as np
import pandas as pd

from mindinsight.lineagemgr.model import LineageTable
from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType
from mindinsight.optimizer.common.log import logger


def generate_param(param_info, n=1):
    """Generate n random values for a single param."""
    value = None
    if HyperParamKey.BOUND.value in param_info:
        bound = param_info[HyperParamKey.BOUND.value]
        value = np.random.uniform(bound[0], bound[1], n)
        if param_info[HyperParamKey.TYPE.value] == HyperParamType.INT.value:
            value = value.astype(HyperParamType.INT.value)
    if HyperParamKey.CHOICE.value in param_info:
        indexes = np.random.randint(0, len(param_info[HyperParamKey.CHOICE.value]), n)
        value = [param_info[HyperParamKey.CHOICE.value][index] for index in indexes]
    if HyperParamKey.DECIMAL.value in param_info:
        value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
    return np.array(value)


def generate_arrays(params_info: dict, n=1):
    """Generate arrays with one column per param."""
    suggest_params = None
    for _, param_info in params_info.items():
        suggest_param = generate_param(param_info, n).reshape((-1, 1))
        if suggest_params is None:
            suggest_params = suggest_param
        else:
            suggest_params = np.hstack((suggest_params, suggest_param))
    if n == 1:
        return suggest_params[0]
    return suggest_params


def match_value_type(array, params_info: dict):
    """Make the values in array match the configured param types."""
    array_new = []
    index = 0
    for _, param_info in params_info.items():
        param_type = param_info[HyperParamKey.TYPE.value]
        value = array[index]
        if HyperParamKey.BOUND.value in param_info:
            bound = param_info[HyperParamKey.BOUND.value]
            value = max(bound[0], value)
            value = min(bound[1], value)
        if HyperParamKey.CHOICE.value in param_info:
            choices = param_info[HyperParamKey.CHOICE.value]
            nearest_index = int(np.argmin(np.fabs(np.array(choices) - value)))
            value = choices[nearest_index]
        if param_type == HyperParamType.INT.value:
            value = int(value)
        if HyperParamKey.DECIMAL.value in param_info:
            value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
        array_new.append(value)
        index += 1
    return array_new


def organize_params_target(lineage_table: LineageTable, params_info: dict, target_name):
    """Organize params and target from the lineage table."""
    # Empty DataFrames keep the return type consistent, so callers can test `.empty`.
    empty_result = pd.DataFrame()
    if lineage_table is None:
        return empty_result, empty_result

    param_keys = list(params_info.keys())

    lineage_df = lineage_table.dataframe_data
    try:
        lineage_df = lineage_df[param_keys + [target_name]]
        lineage_df = lineage_df.dropna(axis=0, how='any')
        return lineage_df[param_keys], lineage_df[[target_name]]
    except KeyError as exc:
        logger.warning("Some keys do not exist in the specified params or target. Params will be "
                       "suggested randomly. Detail: %s.", str(exc))
        return empty_result, empty_result
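A minimal sketch of how these helpers behave, reusing the hypothetical params_info from the enums section above (mindinsight must be importable for this to run):

from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type

params_info = {
    'learning_rate': {'bounds': [0.0001, 0.1], 'type': 'float', 'decimal': 4},
    'batch_size': {'choice': [16, 32, 64, 128], 'type': 'int'},
}

samples = generate_arrays(params_info, n=5)  # shape (5, 2): one column per parameter
single = generate_arrays(params_info)        # shape (2,): a single random suggestion
snapped = match_value_type([0.05001, 57.3], params_info)
print(snapped)                               # roughly [0.05, 64]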

mindinsight/optimizer/utils/transformer.py (+29 -0)

@@ -0,0 +1,29 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer."""
from mindinsight.optimizer.utils.param_handler import match_value_type


class Transformer:
    """Transformer between tuner suggestions and param dicts."""

    @staticmethod
    def transform_list_to_dict(params_info, suggest_list):
        """Transform a suggestion list from the tuner into a dict keyed by param name."""
        suggest_list = match_value_type(suggest_list, params_info)
        param_dict = {}
        for index, param_name in enumerate(params_info):
            param_dict.update({param_name: suggest_list[index]})

        return param_dict
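Continuing the hypothetical params_info example, transform_list_to_dict pairs a raw suggestion array with the parameter names, snapping values through match_value_type on the way:

from mindinsight.optimizer.utils.transformer import Transformer

params_info = {
    'learning_rate': {'bounds': [0.0001, 0.1], 'type': 'float', 'decimal': 4},
    'batch_size': {'choice': [16, 32, 64, 128], 'type': 'int'},
}
suggestion = Transformer.transform_list_to_dict(params_info, [0.05001, 57.3])
print(suggestion)  # roughly {'learning_rate': 0.05, 'batch_size': 64}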

mindinsight/utils/constant.py (+2 -0)

@@ -96,3 +96,5 @@ class OptimizerErrors(Enum):
"""Enum definition for optimizer errors."""
SAMPLES_NOT_ENOUGH = 1
CORRELATION_NAN = 2
HYPER_CONFIG_ERROR = 3
OPTIMIZER_TERMINATE = 4

requirements.txt (+3 -0)

@@ -10,6 +10,9 @@ marshmallow>=2.19.2
 numpy>=1.17.0
 protobuf>=3.8.0
 psutil>=5.6.1
+pyyaml>=5.3
+scipy>=1.4.1
+scikit-learn>=0.23.1
 six>=1.12.0
 Werkzeug>=1.0.0
 pandas>=1.0.4

