| @@ -14,11 +14,12 @@ | |||
| # ============================================================================ | |||
| """Optimizer API module.""" | |||
| import json | |||
| import pandas as pd | |||
| from flask import Blueprint, jsonify, request | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER | |||
| from mindinsight.lineagemgr.model import get_lineage_table | |||
| from mindinsight.lineagemgr.model import get_flattened_lineage, LineageTable | |||
| from mindinsight.optimizer.common.enums import ReasonCode | |||
| from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError | |||
| from mindinsight.optimizer.utils.importances import calc_hyper_param_importance | |||
| @@ -41,9 +42,10 @@ def get_optimize_targets(): | |||
| return jsonify(response) | |||
| def _get_optimize_targets(data_manager, search_condition): | |||
| def _get_optimize_targets(data_manager, search_condition=None): | |||
| """Get optimize targets.""" | |||
| table = get_lineage_table(data_manager, search_condition) | |||
| flatten_lineage = get_flattened_lineage(data_manager, search_condition) | |||
| table = LineageTable(pd.DataFrame(flatten_lineage)) | |||
| target_summaries = [] | |||
| for target in table.target_names: | |||
| @@ -51,7 +53,7 @@ def _get_optimize_targets(data_manager, search_condition): | |||
| for hyper_param in table.hyper_param_names: | |||
| param_info = {"name": hyper_param} | |||
| try: | |||
| importance = calc_hyper_param_importance(table.df, hyper_param, target) | |||
| importance = calc_hyper_param_importance(table.dataframe_data, hyper_param, target) | |||
| param_info.update({"importance": importance}) | |||
| except SamplesNotEnoughError: | |||
| param_info.update({"importance": 0}) | |||
| @@ -194,8 +194,18 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition): | |||
| return search_condition | |||
| def get_lineage_table(data_manager, search_condition): | |||
| """Get lineage data in a table from data manager.""" | |||
| def get_flattened_lineage(data_manager, search_condition=None): | |||
| """ | |||
| Get flattened lineage data from data manager. | |||
| Args: | |||
| data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading. | |||
| search_condition (dict): The search condition. | |||
| Returns: | |||
| Dict[str, list]: A dict containing keys and values from lineages. | |||
| """ | |||
| summary_base_dir = data_manager.summary_base_dir | |||
| lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition) | |||
| lineage_objects = lineages.get("object", []) | |||
| @@ -206,7 +216,7 @@ def get_lineage_table(data_manager, search_condition): | |||
| # Step 2, collect data | |||
| column_data = _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir) | |||
| return LineageTable(pd.DataFrame(column_data)) | |||
| return column_data | |||
| def _get_columns_name(lineage_objects): | |||
| @@ -236,7 +246,7 @@ def _get_columns_name(lineage_objects): | |||
| user_defined_num += len(names) | |||
| log.info("Partial user_defined_info is deleted. Currently saved length is: %s.", user_defined_num) | |||
| else: | |||
| log.info("The quantity of user_defined_info has reached the limit %s.", USER_DEFINED_INFO_LIMIT) | |||
| log.warning("The quantity of user_defined_info has reached the limit %s.", USER_DEFINED_INFO_LIMIT) | |||
| column_names.update(["train_id"]) | |||
| return column_names | |||
| @@ -364,7 +374,7 @@ class LineageTable: | |||
| return [None if np.isnan(num) else num for num in self._df[name].tolist()] | |||
| @property | |||
| def df(self): | |||
| def dataframe_data(self): | |||
| """Get the DataFrame.""" | |||
| return self._df | |||
| @@ -42,9 +42,10 @@ def calc_histogram(np_value: np.ndarray, bins=_DEFAULT_HISTOGRAM_BINS): | |||
| range_left = min_val | |||
| range_right = max_val | |||
| default_half_range = 0.5 | |||
| if range_left >= range_right: | |||
| range_left -= 0.5 | |||
| range_right += 0.5 | |||
| range_left -= default_half_range | |||
| range_right += default_half_range | |||
| with np.errstate(invalid='ignore'): | |||
| # if don't ignore state above, when np.nan exists, | |||