Browse Source

add optimizer:

1. add functions for calculating target buckets and params importances
2. add restful api for target buckets, importances and metadata
tags/v1.0.0
luopengting 5 years ago
parent
commit
b274c774ed
12 changed files with 535 additions and 16 deletions
  1. +26
    -0
      mindinsight/backend/optimizer/__init__.py
  2. +101
    -0
      mindinsight/backend/optimizer/optimizer_api.py
  3. +25
    -2
      mindinsight/lineagemgr/common/utils.py
  4. +156
    -14
      mindinsight/lineagemgr/model.py
  5. +14
    -0
      mindinsight/optimizer/__init__.py
  6. +32
    -0
      mindinsight/optimizer/common/enums.py
  7. +34
    -0
      mindinsight/optimizer/common/exceptions.py
  8. +19
    -0
      mindinsight/optimizer/common/log.py
  9. +14
    -0
      mindinsight/optimizer/utils/__init__.py
  10. +37
    -0
      mindinsight/optimizer/utils/importances.py
  11. +69
    -0
      mindinsight/optimizer/utils/utils.py
  12. +8
    -0
      mindinsight/utils/constant.py

+ 26
- 0
mindinsight/backend/optimizer/__init__.py View File

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Optimizer API module."""
from mindinsight.backend.optimizer.optimizer_api import init_module as init_optimizer_model


def init_module(app):
"""
Init module entry.

Args:
app: Flask. A Flask instance.
"""
init_optimizer_model(app)

+ 101
- 0
mindinsight/backend/optimizer/optimizer_api.py View File

@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Optimizer API module."""
import json
from flask import Blueprint, jsonify, request

from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
from mindinsight.lineagemgr.model import get_lineage_table
from mindinsight.optimizer.common.enums import ReasonCode
from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
from mindinsight.optimizer.utils.utils import calc_histogram
from mindinsight.utils.exceptions import ParamValueError

BLUEPRINT = Blueprint("optimizer", __name__, url_prefix=settings.URL_PATH_PREFIX+settings.API_PREFIX)


@BLUEPRINT.route("/optimizer/targets/search", methods=["POST"])
def get_optimize_targets():
"""Get optimize targets."""
search_condition = request.stream.read()
try:
search_condition = json.loads(search_condition if search_condition else "{}")
except Exception:
raise ParamValueError("Json data parse failed.")

response = _get_optimize_targets(DATA_MANAGER, search_condition)
return jsonify(response)


def _get_optimize_targets(data_manager, search_condition):
"""Get optimize targets."""
table = get_lineage_table(data_manager, search_condition)

target_summaries = []
for target in table.target_names:
hyper_parameters = []
for hyper_param in table.hyper_param_names:
param_info = {"name": hyper_param}
try:
importance = calc_hyper_param_importance(table.df, hyper_param, target)
param_info.update({"importance": importance})
except SamplesNotEnoughError:
param_info.update({"importance": 0})
param_info.update({"reason_code": ReasonCode.SAMPLES_NOT_ENOUGH.value})
except CorrelationNanError:
param_info.update({"importance": 0})
param_info.update({"reason_code": ReasonCode.CORRELATION_NAN.value})
hyper_parameters.append(param_info)

hyper_parameters.sort(key=lambda hyper_param: hyper_param.get("importance"), reverse=True)

target_summary = {
"name": target,
"buckets": calc_histogram(table.get_column(target)),
"hyper_parameters": hyper_parameters,
"data": table.get_column_values(target)
}
target_summaries.append(target_summary)

target_summaries.sort(key=lambda summary: summary.get("name"))

hyper_params_metadata = [{
"name": hyper_param,
"data": table.get_column_values(hyper_param)
} for hyper_param in table.hyper_param_names]

result = {
"metadata": {
"train_ids": table.train_ids,
"possible_hyper_parameters": hyper_params_metadata,
"unrecognized_params": table.drop_column_info
},
"targets": target_summaries
}

return result


def init_module(app):
"""
Init module entry.

Args:
app: the application obj.

"""
app.register_blueprint(BLUEPRINT)

+ 25
- 2
mindinsight/lineagemgr/common/utils.py View File

@@ -16,6 +16,7 @@
import os import os
import re import re
from functools import wraps from functools import wraps
from pathlib import Path


from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \
@@ -99,7 +100,29 @@ def make_directory(path):
try: try:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
real_path = path real_path = path
except PermissionError as e:
log.error("No write permission on the directory(%r), error = %r", path, e)
except PermissionError as err:
log.error("No write permission on the directory(%r), error = %r", path, err)
raise LineageParamTypeError("No write permission on the directory.") raise LineageParamTypeError("No write permission on the directory.")
return real_path return real_path


def get_relative_path(path, base_path):
"""
Get relative path based on base_path.

Args:
path (str): absolute path.
base_path: absolute base path.

Returns:
str, relative path based on base_path.

"""
try:
r_path = str(Path(path).relative_to(Path(base_path)))
except ValueError:
raise LineageParamValueError("The path %r does not start with %r." % (path, base_path))

if r_path == ".":
r_path = ""
return os.path.join("./", r_path)

+ 156
- 14
mindinsight/lineagemgr/model.py View File

@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""This file is used to define the model lineage python api.""" """This file is used to define the model lineage python api."""
import os import os
import numpy as np
import pandas as pd import pandas as pd


from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamValueError, \ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamValueError, \
@@ -21,17 +22,21 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamValu
LineageQuerierParamException, LineageDirNotExistError, LineageSearchConditionParamError, \ LineageQuerierParamException, LineageDirNotExistError, LineageSearchConditionParamError, \
LineageParamTypeError, LineageSummaryParseException LineageParamTypeError, LineageSummaryParseException
from mindinsight.lineagemgr.common.log import logger as log from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.utils import normalize_summary_dir
from mindinsight.lineagemgr.common.utils import normalize_summary_dir, get_relative_path
from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.validate import validate_filter_key, validate_search_model_condition, \ from mindinsight.lineagemgr.common.validator.validate import validate_filter_key, validate_search_model_condition, \
validate_condition, validate_path, validate_train_id validate_condition, validate_path, validate_train_id
from mindinsight.lineagemgr.lineage_parser import LineageParser, LineageOrganizer from mindinsight.lineagemgr.lineage_parser import LineageParser, LineageOrganizer
from mindinsight.lineagemgr.querier.querier import Querier from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.optimizer.common.enums import ReasonCode
from mindinsight.optimizer.utils.utils import is_simple_numpy_number
from mindinsight.utils.exceptions import MindInsightException from mindinsight.utils.exceptions import MindInsightException


_METRIC_PREFIX = "[M]" _METRIC_PREFIX = "[M]"
_USER_DEFINED_PREFIX = "[U]" _USER_DEFINED_PREFIX = "[U]"


USER_DEFINED_INFO_LIMIT = 100



def get_summary_lineage(data_manager=None, summary_dir=None, keys=None): def get_summary_lineage(data_manager=None, summary_dir=None, keys=None):
""" """
@@ -189,44 +194,181 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition):
return search_condition return search_condition




def get_lineage_table(data_manager):
def get_lineage_table(data_manager, search_condition):
"""Get lineage data in a table from data manager.""" """Get lineage data in a table from data manager."""
lineages = filter_summary_lineage(data_manager=data_manager)
summary_base_dir = data_manager.summary_base_dir
lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition)
lineage_objects = lineages.get("object", []) lineage_objects = lineages.get("object", [])
cnt_lineages = len(lineage_objects)
metric_prefix = _METRIC_PREFIX
user_defined_prefix = _USER_DEFINED_PREFIX

# Step 1, get column names # Step 1, get column names
column_names = _get_columns_name(lineage_objects)

# Step 2, collect data
column_data = _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir)

return LineageTable(pd.DataFrame(column_data))


def _get_columns_name(lineage_objects):
"""Get columns name."""
column_names = set() column_names = set()
user_defined_num = 0
for lineage in lineage_objects: for lineage in lineage_objects:
model_lineage = lineage.get("model_lineage", {}) model_lineage = lineage.get("model_lineage", {})
metric = model_lineage.get("metric", {}) metric = model_lineage.get("metric", {})
metric_names = tuple('{}{}'.format(metric_prefix, key) for key in metric.keys())
metric_names = tuple('{}{}'.format(_METRIC_PREFIX, key) for key in metric.keys())
user_defined = model_lineage.get("user_defined", {}) user_defined = model_lineage.get("user_defined", {})
user_defined_names = tuple('{}{}'.format(metric_prefix, key) for key in user_defined.keys())
user_defined_names = tuple('{}{}'.format(_USER_DEFINED_PREFIX, key) for key in user_defined.keys())
model_lineage_temp = list(model_lineage.keys()) model_lineage_temp = list(model_lineage.keys())
for key in model_lineage_temp: for key in model_lineage_temp:
if key in ["metric", "user_defined"]: if key in ["metric", "user_defined"]:
model_lineage_temp.remove(key) model_lineage_temp.remove(key)
column_names.update(model_lineage_temp) column_names.update(model_lineage_temp)
column_names.update(metric_names) column_names.update(metric_names)
column_names.update(user_defined_names)
# Step 2, collect data
if user_defined_num + len(user_defined_names) <= USER_DEFINED_INFO_LIMIT:
column_names.update(user_defined_names)
user_defined_num += len(user_defined_names)
elif user_defined_num < USER_DEFINED_INFO_LIMIT <= user_defined_num + len(user_defined_names):
names = []
for i in range(USER_DEFINED_INFO_LIMIT - user_defined_num):
names.append(user_defined_names[i])
column_names.update(names)
user_defined_num += len(names)
log.info("Partial user_defined_info is deleted. Currently saved length is: %s.", user_defined_num)
else:
log.info("The quantity of user_defined_info has reached the limit %s.", USER_DEFINED_INFO_LIMIT)
column_names.update(["train_id"])

return column_names


def _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir):
"""Collect data and transform to matrix."""
cnt_lineages = len(lineage_objects)
column_data = {key: [None] * cnt_lineages for key in column_names} column_data = {key: [None] * cnt_lineages for key in column_names}
for ind, lineage in enumerate(lineage_objects): for ind, lineage in enumerate(lineage_objects):

train_id = get_relative_path(lineage.get("summary_dir"), summary_base_dir)

model_lineage = lineage.get("model_lineage", {}) model_lineage = lineage.get("model_lineage", {})
metric = model_lineage.pop("metric", {}) metric = model_lineage.pop("metric", {})
metric_content = { metric_content = {
'{}{}'.format(metric_prefix, key): val for key, val in metric.items()
'{}{}'.format(_METRIC_PREFIX, key): val for key, val in metric.items()
} }
user_defined = model_lineage.pop("user_defined", {}) user_defined = model_lineage.pop("user_defined", {})
user_defined_content = { user_defined_content = {
'{}{}'.format(user_defined_prefix, key): val for key, val in user_defined.items()
'{}{}'.format(_USER_DEFINED_PREFIX, key): val for key, val in user_defined.items()
} }
final_content = {} final_content = {}
final_content.update(model_lineage) final_content.update(model_lineage)
final_content.update(metric_content) final_content.update(metric_content)
final_content.update(user_defined_content) final_content.update(user_defined_content)
final_content.update({"train_id": train_id})
for key, val in final_content.items(): for key, val in final_content.items():
column_data[key][ind] = val
return pd.DataFrame(column_data)
if isinstance(val, str) and val.lower() in ['nan', 'inf']:
val = np.nan
if key in column_data:
column_data[key][ind] = val
return column_data


class LineageTable:
"""Wrap lineage data in a table."""
_LOSS_NAME = "loss"
_NOT_TUNABLE_NAMES = [_LOSS_NAME, "train_id", "device_num", "model_size",
"test_dataset_count", "train_dataset_count"]

def __init__(self, df: pd.DataFrame):
self._df = df
self.train_ids = self._df["train_id"].tolist()
self._drop_columns_info = []
self._remove_unsupported_columns()

def _remove_unsupported_columns(self):
"""Remove unsupported columns."""
columns_to_drop = []
for name, data in self._df.iteritems():
if not is_simple_numpy_number(data.dtype):
columns_to_drop.append(name)

if columns_to_drop:
log.debug("Unsupported columns: %s", columns_to_drop)
self._df = self._df.drop(columns=columns_to_drop)

for name in columns_to_drop:
if not name.startswith(_USER_DEFINED_PREFIX):
continue
self._drop_columns_info.append({
"name": name,
"unselected": True,
"reason_code": ReasonCode.NOT_ALL_NUMBERS.value
})

@property
def target_names(self):
"""Get names for optimize targets (eg loss, accuracy)."""
target_names = [name for name in self._df.columns if name.startswith(_METRIC_PREFIX)]
if self._LOSS_NAME in self._df.columns:
target_names.append(self._LOSS_NAME)
return target_names

@property
def hyper_param_names(self, tunable=True):
"""Get hyper param names."""
blocked_names = self._get_blocked_names(tunable)

hyper_param_names = [
name for name in self._df.columns
if not name.startswith(_METRIC_PREFIX) and name not in blocked_names]

if self._LOSS_NAME in hyper_param_names:
hyper_param_names.remove(self._LOSS_NAME)

return hyper_param_names

def _get_blocked_names(self, tunable):
if tunable:
block_names = self._NOT_TUNABLE_NAMES
else:
block_names = []
return block_names

@property
def user_defined_hyper_param_names(self):
"""Get user defined hyper param names."""
names = [name for name in self._df.columns if name.startswith(_USER_DEFINED_PREFIX)]
return names

def get_column(self, name):
"""
Get data for specified column.
Args:
name (str): column name.

Returns:
np.ndarray, specified column.

"""
return self._df[name]

def get_column_values(self, name):
"""
Get data for specified column.
Args:
name (str): column name.

Returns:
list, specified column data. If value is np.nan, transform to None.

"""
return [None if np.isnan(num) else num for num in self._df[name].tolist()]

@property
def df(self):
"""Get the DataFrame."""
return self._df

@property
def drop_column_info(self):
"""Get dropped columns info."""
return self._drop_columns_info

+ 14
- 0
mindinsight/optimizer/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 32
- 0
mindinsight/optimizer/common/enums.py View File

@@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Enums."""

import enum


class BaseEnum(enum.Enum):
"""Base enum."""
@classmethod
def list_members(cls):
"""List all members."""
return [member.value for member in cls]


class ReasonCode(BaseEnum):
"""Reason code for calculating importance."""
NOT_ALL_NUMBERS = 1
SAMPLES_NOT_ENOUGH = 2
CORRELATION_NAN = 3

+ 34
- 0
mindinsight/optimizer/common/exceptions.py View File

@@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define custom exception."""

from mindinsight.utils.constant import OptimizerErrors
from mindinsight.utils.exceptions import MindInsightException


class SamplesNotEnoughError(MindInsightException):
"""Param importance calculated error."""
def __init__(self, error_msg="Param importance calculated error."):
super(SamplesNotEnoughError, self).__init__(OptimizerErrors.SAMPLES_NOT_ENOUGH,
error_msg,
http_code=400)


class CorrelationNanError(MindInsightException):
"""Param importance calculated error."""
def __init__(self, error_msg="Param importance calculated error."):
super(CorrelationNanError, self).__init__(OptimizerErrors.CORRELATION_NAN,
error_msg,
http_code=400)

+ 19
- 0
mindinsight/optimizer/common/log.py View File

@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Logger"""

from mindinsight.utils.log import setup_logger

logger = setup_logger("optimizer", "optimizer")

+ 14
- 0
mindinsight/optimizer/utils/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 37
- 0
mindinsight/optimizer/utils/importances.py View File

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for calculate importance."""
import numpy as np

from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
from mindinsight.optimizer.common.log import logger


def calc_hyper_param_importance(df, hyper_param, target):
"""Calc hyper param importance relative to given target."""
logger.debug("Calculating importance for hyper_param %s, target is %s.", hyper_param, target)

new_df = df[[hyper_param, target]]
no_missing_value_df = new_df.dropna()

# Can not calc pearson correlation coefficient when number of samples is less or equal than 2
if len(no_missing_value_df) <= 2:
raise SamplesNotEnoughError("Number of samples is less or equal than 2.")

correlation = no_missing_value_df[target].corr(no_missing_value_df[hyper_param])
if np.isnan(correlation):
logger.warning("Correlation is nan!")
raise CorrelationNanError("Correlation is nan!")
return abs(correlation)

+ 69
- 0
mindinsight/optimizer/utils/utils.py View File

@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for optimizer."""
import numpy as np

_DEFAULT_HISTOGRAM_BINS = 5


def calc_histogram(np_value: np.ndarray, bins=_DEFAULT_HISTOGRAM_BINS):
"""
Calculates histogram.

This is a simple wrapper around the error-prone np.histogram() to improve robustness.
"""
ma_value = np.ma.masked_invalid(np_value)

valid_cnt = ma_value.count()
if not valid_cnt:
max_val = 0
min_val = 0
else:
# Note that max of a masked array with dtype np.float16 returns inf (numpy issue#15077).
if np.issubdtype(np_value.dtype, np.floating):
max_val = ma_value.max(fill_value=np.NINF)
min_val = ma_value.min(fill_value=np.PINF)
else:
max_val = ma_value.max()
min_val = ma_value.min()

range_left = min_val
range_right = max_val

if range_left >= range_right:
range_left -= 0.5
range_right += 0.5

with np.errstate(invalid='ignore'):
# if don't ignore state above, when np.nan exists,
# it will occur RuntimeWarning: invalid value encountered in less_equal
counts, edges = np.histogram(np_value, bins=bins, range=(range_left, range_right))

histogram_bins = [None] * len(counts)
for ind, count in enumerate(counts):
histogram_bins[ind] = [float(edges[ind]), float(edges[ind + 1] - edges[ind]), float(count)]

return histogram_bins


def is_simple_numpy_number(dtype):
"""Verify if it is simple number."""
if np.issubdtype(dtype, np.integer):
return True

if np.issubdtype(dtype, np.floating):
return True

return False

+ 8
- 0
mindinsight/utils/constant.py View File

@@ -32,6 +32,7 @@ class MindInsightModules(Enum):
PROFILERMGR = 6 PROFILERMGR = 6
SCRIPTCONVERTER = 7 SCRIPTCONVERTER = 7
WIZARD = 9 WIZARD = 9
OPTIMIZER = 10




class GeneralErrors(Enum): class GeneralErrors(Enum):
@@ -81,5 +82,12 @@ class DataVisualErrors(Enum):
class ScriptConverterErrors(Enum): class ScriptConverterErrors(Enum):
"""Enum definition for mindconverter errors.""" """Enum definition for mindconverter errors."""



class WizardErrors(Enum): class WizardErrors(Enum):
"""Enum definition for mindwizard errors.""" """Enum definition for mindwizard errors."""


class OptimizerErrors(Enum):
"""Enum definition for optimizer errors."""
SAMPLES_NOT_ENOUGH = 1
CORRELATION_NAN = 2

Loading…
Cancel
Save