# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Optimizer API module."""
import json

import pandas as pd
from flask import Blueprint, jsonify, request

from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
from mindinsight.lineagemgr.model import get_flattened_lineage, LineageTable
from mindinsight.optimizer.common.enums import ReasonCode
from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
from mindinsight.optimizer.utils.utils import calc_histogram
from mindinsight.utils.exceptions import ParamValueError

BLUEPRINT = Blueprint("optimizer", __name__, url_prefix=settings.URL_PATH_PREFIX+settings.API_PREFIX)


@BLUEPRINT.route("/optimizer/targets/search", methods=["POST"])
def get_optimize_targets():
    """
    Handle POST requests on ``/optimizer/targets/search``.

    The raw request body is parsed as JSON and used as the search condition;
    an empty body is treated as an empty JSON object.

    Returns:
        flask.Response, the JSON-serialized result of `_get_optimize_targets`.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.
    """
    raw_body = request.stream.read()
    try:
        search_condition = json.loads(raw_body if raw_body else "{}")
    except Exception as err:
        # Any failure while parsing the body is reported as a parameter
        # error; the original exception is kept as the explicit cause so it
        # still shows up in tracebacks.
        raise ParamValueError("Json data parse failed.") from err

    response = _get_optimize_targets(DATA_MANAGER, search_condition)
    return jsonify(response)


def _summarize_hyper_param(table, hyper_param, target):
    """
    Build the importance entry of one hyper parameter for one target.

    Args:
        table (LineageTable): Table whose `dataframe_data` is handed to
            `calc_hyper_param_importance`.
        hyper_param (str): Name of the hyper parameter.
        target (str): Name of the optimize target.

    Returns:
        dict, always has "name" and "importance". When the importance cannot
        be computed, "importance" is 0 and "reason_code" carries the
        matching `ReasonCode` value.
    """
    param_info = {"name": hyper_param}
    try:
        param_info["importance"] = calc_hyper_param_importance(table.dataframe_data, hyper_param, target)
    except SamplesNotEnoughError:
        param_info["importance"] = 0
        param_info["reason_code"] = ReasonCode.SAMPLES_NOT_ENOUGH.value
    except CorrelationNanError:
        param_info["importance"] = 0
        param_info["reason_code"] = ReasonCode.CORRELATION_NAN.value
    return param_info


def _get_optimize_targets(data_manager, search_condition=None):
    """
    Get optimize targets.

    Args:
        data_manager: Data manager passed through to `get_flattened_lineage`.
        search_condition: Search condition passed through to
            `get_flattened_lineage`. Default: None.

    Returns:
        dict, with two keys:

        - "metadata": train ids, every hyper parameter with its column
          values ("possible_hyper_parameters"), and the table's
          `drop_column_info` ("unrecognized_params").
        - "targets": one summary per target, sorted by target name. Each
          summary holds the target's histogram buckets, its column values and
          its hyper parameters sorted by descending importance.
    """
    flatten_lineage = get_flattened_lineage(data_manager, search_condition)
    table = LineageTable(pd.DataFrame(flatten_lineage))

    target_summaries = []
    for target in table.target_names:
        hyper_parameters = [
            _summarize_hyper_param(table, hyper_param, target)
            for hyper_param in table.hyper_param_names
        ]
        # Most important hyper parameters come first.
        hyper_parameters.sort(key=lambda param_info: param_info.get("importance"), reverse=True)

        target_summaries.append({
            "name": target,
            "buckets": calc_histogram(table.get_column(target)),
            "hyper_parameters": hyper_parameters,
            "data": table.get_column_values(target)
        })

    target_summaries.sort(key=lambda summary: summary.get("name"))

    hyper_params_metadata = [{
        "name": hyper_param,
        "data": table.get_column_values(hyper_param)
    } for hyper_param in table.hyper_param_names]

    return {
        "metadata": {
            "train_ids": table.train_ids,
            "possible_hyper_parameters": hyper_params_metadata,
            "unrecognized_params": table.drop_column_info
        },
        "targets": target_summaries
    }


def init_module(app):
    """
    Init module entry.

    Args:
        app: the application obj.
    """
    app.register_blueprint(BLUEPRINT)