Browse Source

Rename some variables and functions; adjust some log levels

tags/v1.0.0
luopengting 5 years ago
parent
commit
b0d3da8415
3 changed files with 24 additions and 11 deletions
  1. +6
    -4
      mindinsight/backend/optimizer/optimizer_api.py
  2. +15
    -5
      mindinsight/lineagemgr/model.py
  3. +3
    -2
      mindinsight/optimizer/utils/utils.py

+ 6
- 4
mindinsight/backend/optimizer/optimizer_api.py View File

@@ -14,11 +14,12 @@
# ============================================================================
"""Optimizer API module."""
import json
import pandas as pd
from flask import Blueprint, jsonify, request

from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
from mindinsight.lineagemgr.model import get_lineage_table
from mindinsight.lineagemgr.model import get_flattened_lineage, LineageTable
from mindinsight.optimizer.common.enums import ReasonCode
from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
@@ -41,9 +42,10 @@ def get_optimize_targets():
return jsonify(response)


def _get_optimize_targets(data_manager, search_condition):
def _get_optimize_targets(data_manager, search_condition=None):
"""Get optimize targets."""
table = get_lineage_table(data_manager, search_condition)
flatten_lineage = get_flattened_lineage(data_manager, search_condition)
table = LineageTable(pd.DataFrame(flatten_lineage))

target_summaries = []
for target in table.target_names:
@@ -51,7 +53,7 @@ def _get_optimize_targets(data_manager, search_condition):
for hyper_param in table.hyper_param_names:
param_info = {"name": hyper_param}
try:
importance = calc_hyper_param_importance(table.df, hyper_param, target)
importance = calc_hyper_param_importance(table.dataframe_data, hyper_param, target)
param_info.update({"importance": importance})
except SamplesNotEnoughError:
param_info.update({"importance": 0})


+ 15
- 5
mindinsight/lineagemgr/model.py View File

@@ -194,8 +194,18 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition):
return search_condition


def get_lineage_table(data_manager, search_condition):
"""Get lineage data in a table from data manager."""
def get_flattened_lineage(data_manager, search_condition=None):
"""
Get lineage data in a table from data manager.

Args:
data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading.
search_condition (dict): The search condition.

Returns:
Dict[str, list]: A dict contains keys and values from lineages.

"""
summary_base_dir = data_manager.summary_base_dir
lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition)
lineage_objects = lineages.get("object", [])
@@ -206,7 +216,7 @@ def get_lineage_table(data_manager, search_condition):
# Step 2, collect data
column_data = _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir)

return LineageTable(pd.DataFrame(column_data))
return column_data


def _get_columns_name(lineage_objects):
@@ -236,7 +246,7 @@ def _get_columns_name(lineage_objects):
user_defined_num += len(names)
log.info("Partial user_defined_info is deleted. Currently saved length is: %s.", user_defined_num)
else:
log.info("The quantity of user_defined_info has reached the limit %s.", USER_DEFINED_INFO_LIMIT)
log.warning("The quantity of user_defined_info has reached the limit %s.", USER_DEFINED_INFO_LIMIT)
column_names.update(["train_id"])

return column_names
@@ -364,7 +374,7 @@ class LineageTable:
return [None if np.isnan(num) else num for num in self._df[name].tolist()]

@property
def df(self):
def dataframe_data(self):
"""Get the DataFrame."""
return self._df



+ 3
- 2
mindinsight/optimizer/utils/utils.py View File

@@ -42,9 +42,10 @@ def calc_histogram(np_value: np.ndarray, bins=_DEFAULT_HISTOGRAM_BINS):
range_left = min_val
range_right = max_val

default_half_range = 0.5
if range_left >= range_right:
range_left -= 0.5
range_right += 0.5
range_left -= default_half_range
range_right += default_half_range

with np.errstate(invalid='ignore'):
# if don't ignore state above, when np.nan exists,


Loading…
Cancel
Save