1. add functions for calculating target buckets and param importances 2. add restful api for target buckets, importances and metadata
| @@ -0,0 +1,26 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Optimizer API module.""" | |||
| from mindinsight.backend.optimizer.optimizer_api import init_module as init_optimizer_model | |||
| def init_module(app): | |||
| """ | |||
| Init module entry. | |||
| Args: | |||
| app: Flask. A Flask instance. | |||
| """ | |||
| init_optimizer_model(app) | |||
| @@ -0,0 +1,101 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Optimizer API module.""" | |||
| import json | |||
| from flask import Blueprint, jsonify, request | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER | |||
| from mindinsight.lineagemgr.model import get_lineage_table | |||
| from mindinsight.optimizer.common.enums import ReasonCode | |||
| from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError | |||
| from mindinsight.optimizer.utils.importances import calc_hyper_param_importance | |||
| from mindinsight.optimizer.utils.utils import calc_histogram | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| BLUEPRINT = Blueprint("optimizer", __name__, url_prefix=settings.URL_PATH_PREFIX+settings.API_PREFIX) | |||
| @BLUEPRINT.route("/optimizer/targets/search", methods=["POST"]) | |||
| def get_optimize_targets(): | |||
| """Get optimize targets.""" | |||
| search_condition = request.stream.read() | |||
| try: | |||
| search_condition = json.loads(search_condition if search_condition else "{}") | |||
| except Exception: | |||
| raise ParamValueError("Json data parse failed.") | |||
| response = _get_optimize_targets(DATA_MANAGER, search_condition) | |||
| return jsonify(response) | |||
| def _get_optimize_targets(data_manager, search_condition): | |||
| """Get optimize targets.""" | |||
| table = get_lineage_table(data_manager, search_condition) | |||
| target_summaries = [] | |||
| for target in table.target_names: | |||
| hyper_parameters = [] | |||
| for hyper_param in table.hyper_param_names: | |||
| param_info = {"name": hyper_param} | |||
| try: | |||
| importance = calc_hyper_param_importance(table.df, hyper_param, target) | |||
| param_info.update({"importance": importance}) | |||
| except SamplesNotEnoughError: | |||
| param_info.update({"importance": 0}) | |||
| param_info.update({"reason_code": ReasonCode.SAMPLES_NOT_ENOUGH.value}) | |||
| except CorrelationNanError: | |||
| param_info.update({"importance": 0}) | |||
| param_info.update({"reason_code": ReasonCode.CORRELATION_NAN.value}) | |||
| hyper_parameters.append(param_info) | |||
| hyper_parameters.sort(key=lambda hyper_param: hyper_param.get("importance"), reverse=True) | |||
| target_summary = { | |||
| "name": target, | |||
| "buckets": calc_histogram(table.get_column(target)), | |||
| "hyper_parameters": hyper_parameters, | |||
| "data": table.get_column_values(target) | |||
| } | |||
| target_summaries.append(target_summary) | |||
| target_summaries.sort(key=lambda summary: summary.get("name")) | |||
| hyper_params_metadata = [{ | |||
| "name": hyper_param, | |||
| "data": table.get_column_values(hyper_param) | |||
| } for hyper_param in table.hyper_param_names] | |||
| result = { | |||
| "metadata": { | |||
| "train_ids": table.train_ids, | |||
| "possible_hyper_parameters": hyper_params_metadata, | |||
| "unrecognized_params": table.drop_column_info | |||
| }, | |||
| "targets": target_summaries | |||
| } | |||
| return result | |||
| def init_module(app): | |||
| """ | |||
| Init module entry. | |||
| Args: | |||
| app: Flask. A Flask instance. | |||
| """ | |||
| app.register_blueprint(BLUEPRINT) | |||
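For reviewers who want to poke at the new endpoint, here is a minimal request sketch. It assumes a locally running MindInsight on port 8080 with the default `/v1/mindinsight` API prefix and an empty URL path prefix, and the search condition is a hypothetical filter using the lineage search schema; adjust all of these to your deployment.

```python
import requests

# Hypothetical search condition; an empty body ("{}") is also accepted and
# returns every cached train job.
search_condition = {"limit": 10}

resp = requests.post(
    "http://127.0.0.1:8080/v1/mindinsight/optimizer/targets/search",
    json=search_condition,
)
body = resp.json()
print(body["metadata"]["train_ids"])                   # train jobs covered by the result
print([target["name"] for target in body["targets"]])  # e.g. ['[M]accuracy', 'loss']
```

The response mirrors the `result` dict built in `_get_optimize_targets`: per-target histogram buckets, hyper parameters sorted by importance, and the raw column values.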
| @@ -16,6 +16,7 @@ | |||
| import os | |||
| import re | |||
| from functools import wraps | |||
| from pathlib import Path | |||
| from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \ | |||
| @@ -99,7 +100,29 @@ def make_directory(path): | |||
| try: | |||
| os.makedirs(path, exist_ok=True) | |||
| real_path = path | |||
| except PermissionError as e: | |||
| log.error("No write permission on the directory(%r), error = %r", path, e) | |||
| except PermissionError as err: | |||
| log.error("No write permission on the directory(%r), error = %r", path, err) | |||
| raise LineageParamTypeError("No write permission on the directory.") | |||
| return real_path | |||
| def get_relative_path(path, base_path): | |||
| """ | |||
| Get relative path based on base_path. | |||
| Args: | |||
| path (str): absolute path. | |||
| base_path (str): absolute base path. | |||
| Returns: | |||
| str, relative path based on base_path. | |||
| """ | |||
| try: | |||
| r_path = str(Path(path).relative_to(Path(base_path))) | |||
| except ValueError: | |||
| raise LineageParamValueError("The path %r does not start with %r." % (path, base_path)) | |||
| if r_path == ".": | |||
| r_path = "" | |||
| return os.path.join("./", r_path) | |||
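A quick usage sketch for `get_relative_path`; the paths are illustrative only and the import assumes this patch is applied.

```python
from mindinsight.lineagemgr.common.utils import get_relative_path

print(get_relative_path("/data/summaries/run1", "/data/summaries"))  # "./run1"
print(get_relative_path("/data/summaries", "/data/summaries"))       # "./"

# A path that is not under base_path raises LineageParamValueError:
# get_relative_path("/tmp/elsewhere", "/data/summaries")
```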
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """This file is used to define the model lineage python api.""" | |||
| import os | |||
| import numpy as np | |||
| import pandas as pd | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamValueError, \ | |||
| @@ -21,17 +22,21 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamValu | |||
| LineageQuerierParamException, LineageDirNotExistError, LineageSearchConditionParamError, \ | |||
| LineageParamTypeError, LineageSummaryParseException | |||
| from mindinsight.lineagemgr.common.log import logger as log | |||
| from mindinsight.lineagemgr.common.utils import normalize_summary_dir | |||
| from mindinsight.lineagemgr.common.utils import normalize_summary_dir, get_relative_path | |||
| from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter | |||
| from mindinsight.lineagemgr.common.validator.validate import validate_filter_key, validate_search_model_condition, \ | |||
| validate_condition, validate_path, validate_train_id | |||
| from mindinsight.lineagemgr.lineage_parser import LineageParser, LineageOrganizer | |||
| from mindinsight.lineagemgr.querier.querier import Querier | |||
| from mindinsight.optimizer.common.enums import ReasonCode | |||
| from mindinsight.optimizer.utils.utils import is_simple_numpy_number | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| _METRIC_PREFIX = "[M]" | |||
| _USER_DEFINED_PREFIX = "[U]" | |||
| USER_DEFINED_INFO_LIMIT = 100 | |||
| def get_summary_lineage(data_manager=None, summary_dir=None, keys=None): | |||
| """ | |||
| @@ -189,44 +194,181 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition): | |||
| return search_condition | |||
| def get_lineage_table(data_manager): | |||
| def get_lineage_table(data_manager, search_condition): | |||
| """Get lineage data in a table from data manager.""" | |||
| lineages = filter_summary_lineage(data_manager=data_manager) | |||
| summary_base_dir = data_manager.summary_base_dir | |||
| lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition) | |||
| lineage_objects = lineages.get("object", []) | |||
| cnt_lineages = len(lineage_objects) | |||
| metric_prefix = _METRIC_PREFIX | |||
| user_defined_prefix = _USER_DEFINED_PREFIX | |||
| # Step 1, get column names | |||
| column_names = _get_columns_name(lineage_objects) | |||
| # Step 2, collect data | |||
| column_data = _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir) | |||
| return LineageTable(pd.DataFrame(column_data)) | |||
| def _get_columns_name(lineage_objects): | |||
| """Get columns name.""" | |||
| column_names = set() | |||
| user_defined_num = 0 | |||
| for lineage in lineage_objects: | |||
| model_lineage = lineage.get("model_lineage", {}) | |||
| metric = model_lineage.get("metric", {}) | |||
| metric_names = tuple('{}{}'.format(metric_prefix, key) for key in metric.keys()) | |||
| metric_names = tuple('{}{}'.format(_METRIC_PREFIX, key) for key in metric.keys()) | |||
| user_defined = model_lineage.get("user_defined", {}) | |||
| user_defined_names = tuple('{}{}'.format(metric_prefix, key) for key in user_defined.keys()) | |||
| user_defined_names = tuple('{}{}'.format(_USER_DEFINED_PREFIX, key) for key in user_defined.keys()) | |||
| # Keep everything except the nested "metric" and "user_defined" entries. | |||
| model_lineage_temp = [key for key in model_lineage if key not in ("metric", "user_defined")] | |||
| column_names.update(model_lineage_temp) | |||
| column_names.update(metric_names) | |||
| column_names.update(user_defined_names) | |||
| # Step 2, collect data | |||
| if user_defined_num + len(user_defined_names) <= USER_DEFINED_INFO_LIMIT: | |||
| column_names.update(user_defined_names) | |||
| user_defined_num += len(user_defined_names) | |||
| elif user_defined_num < USER_DEFINED_INFO_LIMIT <= user_defined_num + len(user_defined_names): | |||
| names = [] | |||
| for i in range(USER_DEFINED_INFO_LIMIT - user_defined_num): | |||
| names.append(user_defined_names[i]) | |||
| column_names.update(names) | |||
| user_defined_num += len(names) | |||
| log.info("Partial user_defined_info is deleted. Currently saved length is: %s.", user_defined_num) | |||
| else: | |||
| log.info("The quantity of user_defined_info has reached the limit %s.", USER_DEFINED_INFO_LIMIT) | |||
| column_names.update(["train_id"]) | |||
| return column_names | |||
| def _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir): | |||
| """Collect data and transform to matrix.""" | |||
| cnt_lineages = len(lineage_objects) | |||
| column_data = {key: [None] * cnt_lineages for key in column_names} | |||
| for ind, lineage in enumerate(lineage_objects): | |||
| train_id = get_relative_path(lineage.get("summary_dir"), summary_base_dir) | |||
| model_lineage = lineage.get("model_lineage", {}) | |||
| metric = model_lineage.pop("metric", {}) | |||
| metric_content = { | |||
| '{}{}'.format(metric_prefix, key): val for key, val in metric.items() | |||
| '{}{}'.format(_METRIC_PREFIX, key): val for key, val in metric.items() | |||
| } | |||
| user_defined = model_lineage.pop("user_defined", {}) | |||
| user_defined_content = { | |||
| '{}{}'.format(user_defined_prefix, key): val for key, val in user_defined.items() | |||
| '{}{}'.format(_USER_DEFINED_PREFIX, key): val for key, val in user_defined.items() | |||
| } | |||
| final_content = {} | |||
| final_content.update(model_lineage) | |||
| final_content.update(metric_content) | |||
| final_content.update(user_defined_content) | |||
| final_content.update({"train_id": train_id}) | |||
| for key, val in final_content.items(): | |||
| column_data[key][ind] = val | |||
| return pd.DataFrame(column_data) | |||
| if isinstance(val, str) and val.lower() in ['nan', 'inf']: | |||
| val = np.nan | |||
| if key in column_data: | |||
| column_data[key][ind] = val | |||
| return column_data | |||
| class LineageTable: | |||
| """Wrap lineage data in a table.""" | |||
| _LOSS_NAME = "loss" | |||
| _NOT_TUNABLE_NAMES = [_LOSS_NAME, "train_id", "device_num", "model_size", | |||
| "test_dataset_count", "train_dataset_count"] | |||
| def __init__(self, df: pd.DataFrame): | |||
| self._df = df | |||
| self.train_ids = self._df["train_id"].tolist() | |||
| self._drop_columns_info = [] | |||
| self._remove_unsupported_columns() | |||
| def _remove_unsupported_columns(self): | |||
| """Remove unsupported columns.""" | |||
| columns_to_drop = [] | |||
| for name, data in self._df.iteritems(): | |||
| if not is_simple_numpy_number(data.dtype): | |||
| columns_to_drop.append(name) | |||
| if columns_to_drop: | |||
| log.debug("Unsupported columns: %s", columns_to_drop) | |||
| self._df = self._df.drop(columns=columns_to_drop) | |||
| for name in columns_to_drop: | |||
| if not name.startswith(_USER_DEFINED_PREFIX): | |||
| continue | |||
| self._drop_columns_info.append({ | |||
| "name": name, | |||
| "unselected": True, | |||
| "reason_code": ReasonCode.NOT_ALL_NUMBERS.value | |||
| }) | |||
| @property | |||
| def target_names(self): | |||
| """Get names for optimize targets (eg loss, accuracy).""" | |||
| target_names = [name for name in self._df.columns if name.startswith(_METRIC_PREFIX)] | |||
| if self._LOSS_NAME in self._df.columns: | |||
| target_names.append(self._LOSS_NAME) | |||
| return target_names | |||
| @property | |||
| def hyper_param_names(self, tunable=True): | |||
| """Get hyper param names.""" | |||
| blocked_names = self._get_blocked_names(tunable) | |||
| hyper_param_names = [ | |||
| name for name in self._df.columns | |||
| if not name.startswith(_METRIC_PREFIX) and name not in blocked_names] | |||
| if self._LOSS_NAME in hyper_param_names: | |||
| hyper_param_names.remove(self._LOSS_NAME) | |||
| return hyper_param_names | |||
| def _get_blocked_names(self, tunable): | |||
| if tunable: | |||
| block_names = self._NOT_TUNABLE_NAMES | |||
| else: | |||
| block_names = [] | |||
| return block_names | |||
| @property | |||
| def user_defined_hyper_param_names(self): | |||
| """Get user defined hyper param names.""" | |||
| names = [name for name in self._df.columns if name.startswith(_USER_DEFINED_PREFIX)] | |||
| return names | |||
| def get_column(self, name): | |||
| """ | |||
| Get data for specified column. | |||
| Args: | |||
| name (str): column name. | |||
| Returns: | |||
| pandas.Series, the specified column. | |||
| """ | |||
| return self._df[name] | |||
| def get_column_values(self, name): | |||
| """ | |||
| Get data for specified column. | |||
| Args: | |||
| name (str): column name. | |||
| Returns: | |||
| list, specified column data. If value is np.nan, transform to None. | |||
| """ | |||
| return [None if np.isnan(num) else num for num in self._df[name].tolist()] | |||
| @property | |||
| def df(self): | |||
| """Get the DataFrame.""" | |||
| return self._df | |||
| @property | |||
| def drop_column_info(self): | |||
| """Get dropped columns info.""" | |||
| return self._drop_columns_info | |||
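To make the LineageTable semantics concrete, below is a small sketch driven by a hand-built DataFrame instead of real summary data; it assumes a pandas version contemporary with this change, since `DataFrame.iteritems` is used internally.

```python
import numpy as np
import pandas as pd
from mindinsight.lineagemgr.model import LineageTable

df = pd.DataFrame({
    "train_id": ["./run1", "./run2", "./run3"],
    "learning_rate": [0.1, 0.01, 0.001],
    "loss": [0.9, 0.5, np.nan],
    "[M]accuracy": [0.71, 0.83, 0.88],
    "[U]notes": ["baseline", "tuned", "tuned"],  # non-numeric user-defined info
})

table = LineageTable(df)
print(table.target_names)               # ['[M]accuracy', 'loss']
print(table.hyper_param_names)          # ['learning_rate']; loss, train_id, ... are not tunable
print(table.get_column_values("loss"))  # [0.9, 0.5, None]
print(table.drop_column_info)           # '[U]notes' is dropped with reason_code NOT_ALL_NUMBERS
```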
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Enums.""" | |||
| import enum | |||
| class BaseEnum(enum.Enum): | |||
| """Base enum.""" | |||
| @classmethod | |||
| def list_members(cls): | |||
| """List all members.""" | |||
| return [member.value for member in cls] | |||
| class ReasonCode(BaseEnum): | |||
| """Reason code for calculating importance.""" | |||
| NOT_ALL_NUMBERS = 1 | |||
| SAMPLES_NOT_ENOUGH = 2 | |||
| CORRELATION_NAN = 3 | |||
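A tiny usage note: `list_members` returns the raw values, which is what the API reports in `reason_code` fields.

```python
from mindinsight.optimizer.common.enums import ReasonCode

print(ReasonCode.list_members())            # [1, 2, 3]
print(ReasonCode.SAMPLES_NOT_ENOUGH.value)  # 2
```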
| @@ -0,0 +1,34 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define custom exception.""" | |||
| from mindinsight.utils.constant import OptimizerErrors | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| class SamplesNotEnoughError(MindInsightException): | |||
| """Param importance calculated error.""" | |||
| def __init__(self, error_msg="Param importance calculated error."): | |||
| super(SamplesNotEnoughError, self).__init__(OptimizerErrors.SAMPLES_NOT_ENOUGH, | |||
| error_msg, | |||
| http_code=400) | |||
| class CorrelationNanError(MindInsightException): | |||
| """Param importance calculated error.""" | |||
| def __init__(self, error_msg="Param importance calculated error."): | |||
| super(CorrelationNanError, self).__init__(OptimizerErrors.CORRELATION_NAN, | |||
| error_msg, | |||
| http_code=400) | |||
| @@ -0,0 +1,19 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Logger""" | |||
| from mindinsight.utils.log import setup_logger | |||
| logger = setup_logger("optimizer", "optimizer") | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utils for calculate importance.""" | |||
| import numpy as np | |||
| from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError | |||
| from mindinsight.optimizer.common.log import logger | |||
| def calc_hyper_param_importance(df, hyper_param, target): | |||
| """Calc hyper param importance relative to given target.""" | |||
| logger.debug("Calculating importance for hyper_param %s, target is %s.", hyper_param, target) | |||
| new_df = df[[hyper_param, target]] | |||
| no_missing_value_df = new_df.dropna() | |||
| # The Pearson correlation coefficient cannot be calculated when the number of samples is less than or equal to 2. | |||
| if len(no_missing_value_df) <= 2: | |||
| raise SamplesNotEnoughError("Number of samples is less than or equal to 2.") | |||
| correlation = no_missing_value_df[target].corr(no_missing_value_df[hyper_param]) | |||
| if np.isnan(correlation): | |||
| logger.warning("Correlation is nan!") | |||
| raise CorrelationNanError("Correlation is nan!") | |||
| return abs(correlation) | |||
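A minimal sketch of `calc_hyper_param_importance` on toy data; the numbers are illustrative only.

```python
import numpy as np
import pandas as pd
from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError
from mindinsight.optimizer.utils.importances import calc_hyper_param_importance

df = pd.DataFrame({
    "learning_rate": [0.1, 0.01, 0.001, 0.0001],
    "loss": [0.9, 0.6, 0.4, 0.35],
})
# Absolute Pearson correlation between the two columns (roughly 0.94 for this data).
print(calc_hyper_param_importance(df, "learning_rate", "loss"))

# With two or fewer complete samples the helper raises SamplesNotEnoughError.
tiny = pd.DataFrame({"learning_rate": [0.1, np.nan, 0.01], "loss": [0.9, 0.6, np.nan]})
try:
    calc_hyper_param_importance(tiny, "learning_rate", "loss")
except SamplesNotEnoughError:
    print("not enough samples to calculate the importance")
```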
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utils for optimizer.""" | |||
| import numpy as np | |||
| _DEFAULT_HISTOGRAM_BINS = 5 | |||
| def calc_histogram(np_value: np.ndarray, bins=_DEFAULT_HISTOGRAM_BINS): | |||
| """ | |||
| Calculates histogram. | |||
| This is a simple wrapper around the error-prone np.histogram() to improve robustness. | |||
| """ | |||
| ma_value = np.ma.masked_invalid(np_value) | |||
| valid_cnt = ma_value.count() | |||
| if not valid_cnt: | |||
| max_val = 0 | |||
| min_val = 0 | |||
| else: | |||
| # Note that max of a masked array with dtype np.float16 returns inf (numpy issue#15077). | |||
| if np.issubdtype(np_value.dtype, np.floating): | |||
| max_val = ma_value.max(fill_value=np.NINF) | |||
| min_val = ma_value.min(fill_value=np.PINF) | |||
| else: | |||
| max_val = ma_value.max() | |||
| min_val = ma_value.min() | |||
| range_left = min_val | |||
| range_right = max_val | |||
| if range_left >= range_right: | |||
| range_left -= 0.5 | |||
| range_right += 0.5 | |||
| with np.errstate(invalid='ignore'): | |||
| # Without the errstate above, np.nan values trigger | |||
| # "RuntimeWarning: invalid value encountered in less_equal". | |||
| counts, edges = np.histogram(np_value, bins=bins, range=(range_left, range_right)) | |||
| histogram_bins = [None] * len(counts) | |||
| for ind, count in enumerate(counts): | |||
| histogram_bins[ind] = [float(edges[ind]), float(edges[ind + 1] - edges[ind]), float(count)] | |||
| return histogram_bins | |||
| def is_simple_numpy_number(dtype): | |||
| """Verify if it is simple number.""" | |||
| if np.issubdtype(dtype, np.integer): | |||
| return True | |||
| if np.issubdtype(dtype, np.floating): | |||
| return True | |||
| return False | |||
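And a short sketch of the two helpers, showing the `[left_edge, width, count]` bucket format the API exposes.

```python
import numpy as np
from mindinsight.optimizer.utils.utils import calc_histogram, is_simple_numpy_number

values = np.array([0.1, 0.2, 0.35, 0.7, np.nan, 1.0])
for bucket in calc_histogram(values):
    print(bucket)  # e.g. [0.1, 0.18, 2.0]: starts at 0.1, is 0.18 wide, holds 2 samples

print(is_simple_numpy_number(np.dtype("float32")))  # True
print(is_simple_numpy_number(np.dtype("object")))   # False
```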
| @@ -32,6 +32,7 @@ class MindInsightModules(Enum): | |||
| PROFILERMGR = 6 | |||
| SCRIPTCONVERTER = 7 | |||
| WIZARD = 9 | |||
| OPTIMIZER = 10 | |||
| class GeneralErrors(Enum): | |||
| @@ -81,5 +82,12 @@ class DataVisualErrors(Enum): | |||
| class ScriptConverterErrors(Enum): | |||
| """Enum definition for mindconverter errors.""" | |||
| class WizardErrors(Enum): | |||
| """Enum definition for mindwizard errors.""" | |||
| class OptimizerErrors(Enum): | |||
| """Enum definition for optimizer errors.""" | |||
| SAMPLES_NOT_ENOUGH = 1 | |||
| CORRELATION_NAN = 2 | |||