
add cli, validation and st:

1. add cli entry
2. add more validation for hyper config
3. add validator for config dict
4. add st for optimizer config validator
5. add custom lineage for hyper config
tags/v1.1.0
luopengting 5 years ago
parent commit 0f37c8135e
26 changed files with 829 additions and 99 deletions
  1. mindinsight/backend/optimizer/optimizer_api.py (+1, -1)
  2. mindinsight/lineagemgr/lineage_parser.py (+38, -1)
  3. mindinsight/lineagemgr/model.py (+18, -6)
  4. mindinsight/optimizer/cli.py (+101, -0)
  5. mindinsight/optimizer/common/constants.py (+1, -0)
  6. mindinsight/optimizer/common/enums.py (+27, -0)
  7. mindinsight/optimizer/common/exceptions.py (+16, -0)
  8. mindinsight/optimizer/common/validator/__init__.py (+15, -0)
  9. mindinsight/optimizer/common/validator/optimizer_config.py (+202, -0)
 10. mindinsight/optimizer/hyper_config.py (+70, -24)
 11. mindinsight/optimizer/tuner.py (+33, -38)
 12. mindinsight/optimizer/tuners/gp_tuner.py (+10, -5)
 13. mindinsight/optimizer/utils/param_handler.py (+38, -7)
 14. mindinsight/optimizer/utils/transformer.py (+8, -2)
 15. mindinsight/optimizer/utils/utils.py (+15, -0)
 16. mindinsight/utils/constant.py (+2, -0)
 17. requirements.txt (+2, -2)
 18. setup.py (+1, -0)
 19. tests/st/func/lineagemgr/collection/model/test_model_lineage.py (+3, -3)
 20. tests/st/func/lineagemgr/test_model.py (+21, -9)
 21. tests/ut/lineagemgr/common/validator/test_validate.py (+1, -1)
 22. tests/ut/optimizer/common/__init__.py (+15, -0)
 23. tests/ut/optimizer/common/validator/__init__.py (+15, -0)
 24. tests/ut/optimizer/common/validator/test_optimizer_config.py (+161, -0)
 25. tests/ut/optimizer/utils/__init__.py (+15, -0)
 26. tests/ut/optimizer/utils/test_utils.py (+0, -0)

mindinsight/backend/optimizer/optimizer_api.py (+1, -1)

@@ -43,7 +43,7 @@ def get_optimize_targets():
 
 def _get_optimize_targets(data_manager, search_condition=None):
     """Get optimize targets."""
-    table = get_lineage_table(data_manager, search_condition)
+    table = get_lineage_table(data_manager=data_manager, search_condition=search_condition)
 
     target_summaries = []
     for target in table.target_names:


mindinsight/lineagemgr/lineage_parser.py (+38, -1)

@@ -15,6 +15,7 @@
 """This file is used to parse lineage info."""
 import os
 
+from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
 from mindinsight.lineagemgr.common.exceptions.exceptions import LineageSummaryAnalyzeException, \
     LineageEventNotExistException, LineageEventFieldNotExistException, LineageFileNotFoundError, \
     MindInsightException
@@ -177,13 +178,49 @@ class LineageParser:
 
 class LineageOrganizer:
     """Lineage organizer."""
-    def __init__(self, data_manager):
+    def __init__(self, data_manager=None, summary_base_dir=None):
         self._data_manager = data_manager
+        self._summary_base_dir = summary_base_dir
+        self._check_params()
         self._super_lineage_objs = {}
         self._organize_from_cache()
+        self._organize_from_disk()
 
+    def _check_params(self):
+        """Check params."""
+        if self._data_manager is not None and self._summary_base_dir is not None:
+            self._summary_base_dir = None
+
+    def _organize_from_disk(self):
+        """Organize lineage objs from disk."""
+        if self._summary_base_dir is None:
+            return
+        summary_watcher = SummaryWatcher()
+        relative_dirs = summary_watcher.list_summary_directories(
+            summary_base_dir=self._summary_base_dir
+        )
+
+        no_lineage_count = 0
+        for item in relative_dirs:
+            relative_dir = item.get('relative_path')
+            update_time = item.get('update_time')
+            abs_summary_dir = os.path.realpath(os.path.join(self._summary_base_dir, relative_dir))
+
+            try:
+                lineage_parser = LineageParser(relative_dir, abs_summary_dir, update_time)
+                super_lineage_obj = lineage_parser.super_lineage_obj
+                if super_lineage_obj is not None:
+                    self._super_lineage_objs.update({abs_summary_dir: super_lineage_obj})
+            except LineageFileNotFoundError:
+                no_lineage_count += 1
+
+        if no_lineage_count == len(relative_dirs):
+            logger.info('There is no summary log file under summary_base_dir.')
+
     def _organize_from_cache(self):
         """Organize lineage objs from cache."""
+        if self._data_manager is None:
+            return
         brief_cache = self._data_manager.get_brief_cache()
         cache_items = brief_cache.cache_items
         for relative_dir, cache_train_job in cache_items.items():


mindinsight/lineagemgr/model.py (+18, -6)

@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySumm
 from mindinsight.lineagemgr.common.log import logger as log
 from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
 from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition, validate_condition
+from mindinsight.lineagemgr.common.validator.validate_path import validate_and_normalize_path
 from mindinsight.lineagemgr.lineage_parser import LineageOrganizer
 from mindinsight.lineagemgr.querier.querier import Querier
 from mindinsight.optimizer.common.enums import ReasonCode
@@ -33,7 +34,7 @@ USER_DEFINED_PREFIX = "[U]"
 USER_DEFINED_INFO_LIMIT = 100
 
 
-def filter_summary_lineage(data_manager, search_condition=None):
+def filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None):
     """
     Filter summary lineage from data_manager or parsing from summaries.
 
@@ -43,8 +44,18 @@ def filter_summary_lineage(data_manager, search_condition=None):
     Args:
         data_manager (DataManager): Data manager defined as
             mindinsight.datavisual.data_transform.data_manager.DataManager
+        summary_base_dir (str): The summary base directory. It contains summary
+            directories generated by training.
         search_condition (dict): The search condition.
     """
+    if data_manager is None and summary_base_dir is None:
+        raise LineageParamTypeError("One of data_manager or summary_base_dir needs to be specified.")
+
+    if data_manager is None:
+        summary_base_dir = validate_and_normalize_path(summary_base_dir, 'summary_base_dir')
+    else:
+        summary_base_dir = data_manager.summary_base_dir
+
     search_condition = {} if search_condition is None else search_condition
 
     try:
@@ -56,7 +67,7 @@ def filter_summary_lineage(data_manager, search_condition=None):
         raise LineageSearchConditionParamError(str(error.message))
 
     try:
-        lineage_objects = LineageOrganizer(data_manager).super_lineage_objs
+        lineage_objects = LineageOrganizer(data_manager, summary_base_dir).super_lineage_objs
         result = Querier(lineage_objects).filter_summary_lineage(condition=search_condition)
     except LineageSummaryParseException:
         result = {'object': [], 'count': 0}
@@ -68,12 +79,13 @@ def filter_summary_lineage(data_manager, search_condition=None):
     return result
 
 
-def get_flattened_lineage(data_manager, search_condition=None):
+def get_flattened_lineage(data_manager=None, summary_base_dir=None, search_condition=None):
     """
     Get lineage data in a table from data manager.
 
     Args:
         data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading.
+        summary_base_dir (str): The base directory for train jobs.
         search_condition (dict): The search condition.
 
     Returns:
@@ -81,7 +93,7 @@ def get_flattened_lineage(data_manager, search_condition=None):
 
     """
     flatten_dict, user_count = {'train_id': []}, 0
-    lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition).get("object", [])
+    lineages = filter_summary_lineage(data_manager, summary_base_dir, search_condition).get("object", [])
     for index, lineage in enumerate(lineages):
         flatten_dict['train_id'].append(lineage.get("summary_dir"))
         for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
@@ -222,9 +234,9 @@ class LineageTable:
         return self._drop_columns_info
 
 
-def get_lineage_table(data_manager, search_condition=None):
+def get_lineage_table(data_manager=None, summary_base_dir=None, search_condition=None):
     """Get lineage table from data_manager."""
-    lineage_table = get_flattened_lineage(data_manager, search_condition)
+    lineage_table = get_flattened_lineage(data_manager, summary_base_dir, search_condition)
     lineage_table = LineageTable(pd.DataFrame(lineage_table))
 
     return lineage_table
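
With this change, lineage can be queried straight from summary files on disk, with no DataManager. A minimal sketch of the new call path (the directory path is a hypothetical placeholder):

    from mindinsight.lineagemgr.model import filter_summary_lineage, get_lineage_table

    # Parse lineage directly from summaries on disk; no DataManager needed.
    result = filter_summary_lineage(summary_base_dir='/tmp/summary_base_dir')
    table = get_lineage_table(summary_base_dir='/tmp/summary_base_dir')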

mindinsight/optimizer/cli.py (+101, -0)

@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Command module."""
import argparse
import os
import sys

import mindinsight
from mindinsight.optimizer.tuner import Tuner


class ConfigAction(argparse.Action):
    """Config file path action class definition."""

    def __call__(self, parser_in, namespace, values, option_string=None):
        """
        Inherited __call__ method from argparse.Action.

        Args:
            parser_in (ArgumentParser): Passed-in argument parser.
            namespace (Namespace): Namespace object to hold arguments.
            values (object): Argument values with type depending on argument definition.
            option_string (str): Option string for specific argument name.
        """
        config_path = os.path.realpath(values)
        if not os.path.exists(config_path):
            parser_in.error(f'{option_string} {config_path} does not exist.')

        setattr(namespace, self.dest, config_path)


class IterAction(argparse.Action):
    """Iteration times action class definition."""

    def __call__(self, parser_in, namespace, values, option_string=None):
        """
        Inherited __call__ method from argparse.Action.

        Args:
            parser_in (ArgumentParser): Passed-in argument parser.
            namespace (Namespace): Namespace object to hold arguments.
            values (object): Argument values with type depending on argument definition.
            option_string (str): Option string for specific argument name.
        """
        iter_times = values
        if iter_times <= 0:
            parser_in.error(f'{option_string} {iter_times} should be a positive integer.')

        setattr(namespace, self.dest, iter_times)


parser = argparse.ArgumentParser(
    prog='mindoptimizer',
    description='MindOptimizer CLI entry point (version: {})'.format(mindinsight.__version__))

parser.add_argument(
    '--version',
    action='version',
    version='%(prog)s ({})'.format(mindinsight.__version__))

parser.add_argument(
    '--config',
    type=str,
    action=ConfigAction,
    required=True,
    default=os.path.join(os.getcwd(), 'output'),
    help="Specify path for config file."
)

parser.add_argument(
    '--iter',
    type=int,
    action=IterAction,
    default=1,
    help="Optional, specify run times for the command in config file."
)


def cli_entry():
    """Cli entry."""
    argv = sys.argv[1:]
    if not argv:
        argv = ['-h']
        args = parser.parse_args(argv)
    else:
        args = parser.parse_args()

    tuner = Tuner(args.config)
    tuner.optimize(max_expr_times=args.iter)
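
Together with the setup.py entry below, this file registers the mindoptimizer console script. A hedged usage sketch (the config path and run count are illustrative):

    # Shell usage (illustrative):
    #   mindoptimizer --config ./config.yaml --iter 10
    # The same entry point, driven from Python:
    import sys
    from mindinsight.optimizer.cli import cli_entry

    sys.argv = ['mindoptimizer', '--config', './config.yaml', '--iter', '10']
    cli_entry()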

mindinsight/optimizer/common/constants.py (+1, -0)

@@ -14,3 +14,4 @@
 # ============================================================================
 """Common constants for optimizer."""
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
+HYPER_CONFIG_LEN_LIMIT = 100000

mindinsight/optimizer/common/enums.py (+27, -0)

@@ -48,12 +48,17 @@ class TuneMethod(BaseEnum):
     GP = 'gp'
 
 
+class GPSupportArgs(BaseEnum):
+    """Supported args for the GP tuner."""
+    METHOD = 'method'
+
+
 class HyperParamKey(BaseEnum):
     """Config keys for hyper parameters."""
     BOUND = 'bounds'
     CHOICE = 'choice'
     DECIMAL = 'decimal'
     TYPE = 'type'
+    SOURCE = 'source'
 
 
 class HyperParamType(BaseEnum):
@@ -73,3 +78,25 @@ class TargetGoal(BaseEnum):
     """Goal for target."""
     MAXIMUM = 'maximize'
     MINIMUM = 'minimize'
+
+
+class HyperParamSource(BaseEnum):
+    """Source of a hyper parameter."""
+    SYSTEM_DEFINED = 'system_defined'
+    USER_DEFINED = 'user_defined'
+
+
+class TargetGroup(BaseEnum):
+    """Group of a target."""
+    SYSTEM_DEFINED = 'system_defined'
+    METRIC = 'metric'
+
+
+class TunableSystemDefinedParams(BaseEnum):
+    """Tunable metadata keys of lineage collection."""
+    BATCH_SIZE = 'batch_size'
+    EPOCH = 'epoch'
+    LEARNING_RATE = 'learning_rate'
+
+
+class SystemDefinedTargets(BaseEnum):
+    """System defined targets."""
+    LOSS = 'loss'

mindinsight/optimizer/common/exceptions.py (+16, -0)

@@ -48,3 +48,19 @@ class OptimizerTerminateError(MindInsightException):
         super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE,
                                                       error_msg,
                                                       http_code=400)
+
+
+class ConfigParamError(MindInsightException):
+    """Config parameter error."""
+    def __init__(self, error_msg="Invalid parameter."):
+        super(ConfigParamError, self).__init__(OptimizerErrors.CONFIG_PARAM_ERROR,
+                                               error_msg,
+                                               http_code=400)
+
+
+class HyperConfigEnvError(MindInsightException):
+    """Hyper config environment error."""
+    def __init__(self, error_msg="Hyper config is not correct."):
+        super(HyperConfigEnvError, self).__init__(OptimizerErrors.HYPER_CONFIG_ENV_ERROR,
+                                                  error_msg,
+                                                  http_code=400)

mindinsight/optimizer/common/validator/__init__.py (+15, -0)

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common validator modules for optimizer."""

mindinsight/optimizer/common/validator/optimizer_config.py (+202, -0)

@@ -0,0 +1,202 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Validator for optimizer config."""
import math

from marshmallow import Schema, fields, ValidationError, validates, validate, validates_schema

from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \
    HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \
    HyperParamKey, SystemDefinedTargets

_BOUND_LEN = 2
_NUMBER_ERR_MSG = "Value(s) should be integer or float."
_TYPE_ERR_MSG = "Value type should be %r."
_VALUE_ERR_MSG = "Value should be in %s. Current value is %s."


def _generate_schema_err_msg(err_msg, *args):
    """Organize error messages."""
    if args:
        err_msg = err_msg % args
    return {"invalid": err_msg}


def include_integer(low, high):
    """Check if the range [low, high) includes an integer."""
    def _in_range(num, low, high):
        """Check if num is in [low, high)."""
        return low <= num < high

    if _in_range(math.ceil(low), low, high) or _in_range(math.floor(high), low, high):
        return True
    return False


class TunerSchema(Schema):
    """Schema for tuner."""
    dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")

    name = fields.Str(required=True,
                      validate=validate.OneOf(TuneMethod.list_members()),
                      error_messages=_generate_schema_err_msg("Value should be in %s." % TuneMethod.list_members()))
    args = fields.Dict(error_messages=dict_err_msg)

    @validates("args")
    def check_args(self, data):
        """Check args for tuner."""
        data_keys = list(data.keys())
        support_args = GPSupportArgs.list_members()
        if not set(data_keys).issubset(set(support_args)):
            raise ValidationError("Only support setting %s for tuner. "
                                  "Current key(s): %s." % (support_args, data_keys))

        method = data.get(GPSupportArgs.METHOD.value)
        if not isinstance(method, str):
            raise ValidationError("The 'method' type should be str.")
        if method not in AcquisitionFunctionEnum.list_members():
            raise ValidationError("Supported acquisition function must be one of %s. Current value is %r." %
                                  (AcquisitionFunctionEnum.list_members(), method))


class ParameterSchema(Schema):
    """Schema for parameter."""
    number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG)
    list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list")
    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")

    bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
    choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
    type = fields.Str(error_messages=str_err_msg)
    source = fields.Str(error_messages=str_err_msg)

    @validates("bounds")
    def check_bounds(self, bounds):
        """Check if bounds are valid."""
        if len(bounds) != _BOUND_LEN:
            raise ValidationError("Length of bounds should be %s." % _BOUND_LEN)
        if bounds[1] <= bounds[0]:
            raise ValidationError("The upper bound must be greater than the lower bound. "
                                  "The range is [lower_bound, upper_bound).")

    @validates("type")
    def check_type(self, type_in):
        """Check if type is valid."""
        if type_in not in HyperParamType.list_members():
            raise ValidationError("The type should be in %s." % HyperParamType.list_members())

    @validates("source")
    def check_source(self, source):
        """Check if source is valid."""
        if source not in HyperParamSource.list_members():
            raise ValidationError(_VALUE_ERR_MSG % (HyperParamSource.list_members(), source))

    @validates_schema
    def check_combination(self, data, **kwargs):
        """Check the combination of parameters."""
        bound_key = HyperParamKey.BOUND.value
        choice_key = HyperParamKey.CHOICE.value
        type_key = HyperParamKey.TYPE.value

        # check bound and type
        bounds = data.get(bound_key)
        param_type = data.get(type_key)
        if bounds is not None:
            if param_type is None:
                raise ValidationError("If %r is specified, %r should be specified as well." %
                                      (HyperParamKey.BOUND.value, HyperParamKey.TYPE.value))
            if param_type == HyperParamType.INT.value and not include_integer(bounds[0], bounds[1]):
                raise ValidationError("No integer in 'bounds', please modify it.")

        # check bound and choice
        if (bound_key in data and choice_key in data) or (bound_key not in data and choice_key not in data):
            raise ValidationError("Only one of [%r, %r] should be specified." %
                                  (bound_key, choice_key))


class TargetSchema(Schema):
    """Schema for target."""
    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")

    group = fields.Str(error_messages=str_err_msg)
    name = fields.Str(required=True, error_messages=str_err_msg)
    goal = fields.Str(error_messages=str_err_msg)

    @validates("group")
    def check_group(self, group):
        """Check if group is valid."""
        if group not in TargetGroup.list_members():
            raise ValidationError(_VALUE_ERR_MSG % (TargetGroup.list_members(), group))

    @validates("goal")
    def check_goal(self, goal):
        """Check if goal is valid."""
        if goal not in TargetGoal.list_members():
            raise ValidationError(_VALUE_ERR_MSG % (TargetGoal.list_members(), goal))

    @validates_schema
    def check_combination(self, data, **kwargs):
        """Check the combination of parameters."""
        if TargetKey.GROUP.value not in data:
            # if name is in system_defined keys, group will be 'system_defined', else will be 'user_defined'.
            return
        name = data.get(TargetKey.NAME.value)
        group = data.get(TargetKey.GROUP.value)
        if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members():
            raise ValidationError({
                TargetKey.GROUP.value: "This target is not system defined. Current group is: %s." % group})


class OptimizerConfig(Schema):
    """Define the optimizer config schema."""
    dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")

    summary_base_dir = fields.Str(required=True, error_messages=str_err_msg)
    command = fields.Str(required=True, error_messages=str_err_msg)
    tuner = fields.Dict(required=True, error_messages=dict_err_msg)
    target = fields.Dict(required=True, error_messages=dict_err_msg)
    parameters = fields.Dict(required=True, error_messages=dict_err_msg)

    @validates("tuner")
    def check_tuner(self, data):
        """Check tuner."""
        err = TunerSchema().validate(data)
        if err:
            raise ValidationError(err)

    @validates("parameters")
    def check_parameters(self, parameters):
        """Check parameters."""
        for name, value in parameters.items():
            err = ParameterSchema().validate(value)
            if err:
                raise ValidationError({name: err})

            if HyperParamKey.SOURCE.value not in value:
                # if params is in system_defined keys, source will be 'system_defined', else 'user_defined'.
                continue
            source = value.get(HyperParamKey.SOURCE.value)
            if source == HyperParamSource.SYSTEM_DEFINED.value and \
                    name not in TunableSystemDefinedParams.list_members():
                raise ValidationError({
                    name: {"source": "This param is not system defined. Current source is: %s." % source}})

    @validates("target")
    def check_target(self, target):
        """Check target."""
        err = TargetSchema().validate(target)
        if err:
            raise ValidationError(err)
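
For reference, a minimal config dict this schema should accept; every value here is an illustrative assumption ('gp', 'ucb', 'metric', 'maximize' and 'float' are taken to be members of TuneMethod, AcquisitionFunctionEnum, TargetGroup, TargetGoal and HyperParamType used above):

    from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig

    config = {
        'summary_base_dir': '/tmp/summaries',                # hypothetical path
        'command': 'python train.py',                        # hypothetical training command
        'tuner': {'name': 'gp', 'args': {'method': 'ucb'}},
        'target': {'group': 'metric', 'name': 'accuracy', 'goal': 'maximize'},
        'parameters': {'learning_rate': {'bounds': [0.0001, 0.01], 'type': 'float'}},
    }
    errors = OptimizerConfig().validate(config)              # empty dict when the config is valid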

mindinsight/optimizer/hyper_config.py (+70, -24)

@@ -15,30 +15,73 @@
 """Hyper config."""
 import json
 import os
-from attrdict import AttrDict
 
-from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
-from mindinsight.optimizer.common.exceptions import HyperConfigError
+from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME, HYPER_CONFIG_LEN_LIMIT
+from mindinsight.optimizer.common.exceptions import HyperConfigEnvError, HyperConfigError
 
-_HYPER_CONFIG_LEN_LIMIT = 100000
+
+class AttributeDict(dict):
+    """A dict that can be accessed by attribute."""
+    def __init__(self, d=None):
+        super().__init__()
+        if d is not None:
+            for k, v in d.items():
+                self[k] = v
+
+    def __key(self, key):
+        """Get key."""
+        return "" if key is None else key
+
+    def __setattr__(self, key, value):
+        """Set attribute for object."""
+        self[self.__key(key)] = value
+
+    def __getattr__(self, key):
+        """
+        Get attribute value by attribute name.
+
+        Args:
+            key (str): Attribute name.
+
+        Returns:
+            Any, attribute value.
+
+        Raises:
+            AttributeError: If the key does not exist.
+        """
+        value = self.get(self.__key(key))
+        if value is None:
+            raise AttributeError("The attribute %r does not exist." % key)
+        return value
+
+    def __getitem__(self, key):
+        """Get attribute value by attribute name."""
+        value = super().get(self.__key(key))
+        if value is None:
+            raise AttributeError("The attribute %r does not exist." % key)
+        return value
+
+    def __setitem__(self, key, value):
+        """Set attribute for object."""
+        return super().__setitem__(self.__key(key), value)
 
 
 class HyperConfig:
     """
     Hyper config.
-
-    Init hyper config:
-    >>> hyper_config = HyperConfig()
-
-    Get suggest params:
-    >>> param_obj = hyper_config.params
-    >>> learning_rate = params.learning_rate
-
-    Get summary dir:
-    >>> summary_dir = hyper_config.summary_dir
-
-    Record by SummaryCollector:
-    >>> summary_cb = SummaryCollector(summary_dir)
+    1. Init HyperConfig.
+    2. Get suggested params and summary_dir.
+    3. Record by SummaryCollector with summary_dir.
+
+    Examples:
+        >>> hyper_config = HyperConfig()
+        >>> params = hyper_config.params
+        >>> learning_rate = params.learning_rate
+        >>> batch_size = params.batch_size
+
+        >>> summary_dir = hyper_config.summary_dir
+        >>> summary_cb = SummaryCollector(summary_dir)
     """
     def __init__(self):
         self._init_validate_hyper_config()
@@ -47,10 +90,10 @@ class HyperConfig:
         """Init and validate hyper config."""
        hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
         if hyper_config is None:
-            raise HyperConfigError("Hyper config is not in system environment.")
-        if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT:
-            raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of "
-                                   "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
+            raise HyperConfigEnvError("Hyper config is not in system environment.")
+        if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT:
+            raise HyperConfigEnvError("Hyper config is too long. The length limit is %s, the length of "
+                                      "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
 
         try:
             hyper_config = json.loads(hyper_config)
@@ -60,8 +103,7 @@ class HyperConfig:
             raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc))
 
         self._validate_hyper_config(hyper_config)
-        self._summary_dir = hyper_config.get('summary_dir')
-        self._param_obj = AttrDict(hyper_config.get('params'))
+        self._hyper_config = hyper_config
 
     def _validate_hyper_config(self, hyper_config):
         """Validate hyper config."""
@@ -86,9 +128,13 @@ class HyperConfig:
     @property
     def params(self):
         """Get params."""
-        return self._param_obj
+        return AttributeDict(self._hyper_config.get('params'))
 
     @property
     def summary_dir(self):
         """Get train summary dir path."""
-        return self._summary_dir
+        return self._hyper_config.get('summary_dir')
+
+    @property
+    def custom_lineage_data(self):
+        """Get custom lineage data."""
+        return self._hyper_config.get('custom_lineage_data')
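
A hedged sketch of how a training script would consume this class; SummaryCollector comes from MindSpore, and the parameter name is an assumption about what the config tunes:

    from mindspore.train.callback import SummaryCollector
    from mindinsight.optimizer.hyper_config import HyperConfig

    hyper_config = HyperConfig()
    params = hyper_config.params                  # AttributeDict of suggested values
    learning_rate = params.learning_rate

    # Record training under the suggested summary_dir, together with the
    # user-defined params passed through as custom lineage data.
    summary_cb = SummaryCollector(hyper_config.summary_dir,
                                  custom_lineage_data=hyper_config.custom_lineage_data)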

mindinsight/optimizer/tuner.py (+33, -38)

@@ -22,16 +22,16 @@ import yaml
 
 from marshmallow import ValidationError
 
-from mindinsight.datavisual.data_transform.data_manager import DataManager
-from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
 from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
-from mindinsight.lineagemgr.model import get_lineage_table, LineageTable, METRIC_PREFIX
+from mindinsight.lineagemgr.model import get_lineage_table, LineageTable
 from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
-from mindinsight.optimizer.common.enums import TuneMethod, TargetKey, TargetGoal
-from mindinsight.optimizer.common.exceptions import OptimizerTerminateError
+from mindinsight.optimizer.common.enums import TuneMethod
+from mindinsight.optimizer.common.exceptions import OptimizerTerminateError, ConfigParamError
 from mindinsight.optimizer.common.log import logger
+from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig
 from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner
 from mindinsight.optimizer.utils.param_handler import organize_params_target
+from mindinsight.optimizer.utils.utils import get_nested_message
 from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError
 
 _OK = 0
@@ -51,7 +51,6 @@ class Tuner:
     def __init__(self, config_path: str):
         self._config_info = self._validate_config(config_path)
         self._summary_base_dir = self._config_info.get('summary_base_dir')
-        self._data_manager = self._init_data_manager()
         self._dir_prefix = 'train'
 
     def _validate_config(self, config_path):
@@ -65,11 +64,17 @@ class Tuner:
         except Exception as exc:
             raise UnknownError("Detail: %s." % str(exc))
 
-        # need to add validation for config_info: command, summary_base_dir, target and params.
+        self._validate_config_schema(config_info)
         config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir'))
         self._make_summary_base_dir(config_info['summary_base_dir'])
         return config_info
 
+    def _validate_config_schema(self, config_info):
+        """Validate config schema, raising ConfigParamError on the first nested error."""
+        error = OptimizerConfig().validate(config_info)
+        if error:
+            err_msg = get_nested_message(error)
+            raise ConfigParamError(err_msg)
+
     def _make_summary_base_dir(self, summary_base_dir):
         """Check and make summary_base_dir."""
         if not os.path.exists(summary_base_dir):
@@ -82,13 +87,6 @@ class Tuner:
             except OSError as exc:
                 raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc))
 
-    def _init_data_manager(self):
-        """Initialize data_manager."""
-        data_manager = DataManager(summary_base_dir=self._summary_base_dir)
-        data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater())
-
-        return data_manager
-
     def _normalize_path(self, param_name, path):
         """Normalize config path."""
         path = os.path.realpath(path)
@@ -104,10 +102,8 @@ class Tuner:
 
     def _update_from_lineage(self):
         """Update lineage from lineagemgr."""
-        self._data_manager.start_load_data(reload_interval=0).join()
-
         try:
-            lineage_table = get_lineage_table(self._data_manager)
+            lineage_table = get_lineage_table(summary_base_dir=self._summary_base_dir)
         except MindInsightException as err:
             logger.info("Can not query lineage. Detail: %s", str(err))
             lineage_table = None
@@ -122,12 +118,14 @@ class Tuner:
         tuner = self._config_info.get('tuner')
         for _ in range(max_expr_times):
             self._update_from_lineage()
-            suggestion = self._suggest(self._lineage_table, params_info, target_info, method=tuner.get("name"))
+            suggestion, user_defined_info = self._suggest(self._lineage_table, params_info, target_info, tuner)
 
             hyper_config = {
                 'params': suggestion,
-                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}')
+                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}'),
+                'custom_lineage_data': user_defined_info
             }
+            logger.info("Suggest values are: %s.", suggestion)
             os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config)
             s = subprocess.Popen(shlex.split(command))
             s.wait()
@@ -135,29 +133,26 @@ class Tuner:
                 logger.error("An error occurred during execution, the auto tuning will be terminated.")
                 raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.")
 
-    def _get_tuner(self, tune_method=TuneMethod.GP.value):
+    def _get_tuner(self, tuner):
         """Get tuner."""
-        if tune_method.lower() not in TuneMethod.list_members():
+        if tuner is None:
+            return GPBaseTuner()
+
+        tuner_name = tuner.get("name").lower()
+        if tuner_name not in TuneMethod.list_members():
             raise ParamValueError("'tune_method' should be in %s." % TuneMethod.list_members())
 
+        args = tuner.get("args")
+        if args is not None and args.get("method") is not None:
+            return GPBaseTuner(args.get("method"))
+
         # Only support gaussian process regressor currently.
         return GPBaseTuner()
 
-    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, method):
+    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, tuner):
         """Get suggestions for targets."""
-        tuner = self._get_tuner(method)
-        target_name = target_info[TargetKey.NAME.value]
-        if TargetKey.GROUP.value in target_info and target_info[TargetKey.GROUP.value] == 'metric':
-            target_name = METRIC_PREFIX + target_name
-        param_matrix, target_matrix = organize_params_target(lineage_table, params_info, target_name)
-
-        if not param_matrix.empty:
-            suggestion = tuner.suggest([], [], params_info)
-        else:
-            target_column = target_matrix[target_name].reshape((-1, 1))
-            if target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value:
-                target_column = -target_column
-
-            suggestion = tuner.suggest(param_matrix, target_column, params_info)
-
-        return suggestion
+        tuner = self._get_tuner(tuner)
+        param_matrix, target_column = organize_params_target(lineage_table, params_info, target_info)
+        suggestion, user_defined_info = tuner.suggest(param_matrix, target_column, params_info)
+
+        return suggestion, user_defined_info

mindinsight/optimizer/tuners/gp_tuner.py (+10, -5)

@@ -22,10 +22,11 @@ from sklearn.gaussian_process import GaussianProcessRegressor
 from sklearn.gaussian_process.kernels import Matern
 
 from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey
+from mindinsight.optimizer.common.log import logger
+from mindinsight.optimizer.tuners.base_tuner import BaseTuner
 from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type
 from mindinsight.optimizer.utils.transformer import Transformer
 from mindinsight.utils.exceptions import ParamValueError
-from mindinsight.optimizer.tuners.base_tuner import BaseTuner
 
 
 class AcquisitionFunction:
@@ -141,8 +142,10 @@ class GPBaseTuner(BaseTuner):
 
         x_seeds = generate_arrays(params_info, n_iter)
         for x_try in x_seeds:
-            res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max),
-                           x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max),
+                               x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B")
 
             if not res.success:
                 continue
@@ -164,6 +167,8 @@ class GPBaseTuner(BaseTuner):
 
         min_lineage_rows = 2
         if not np.array(params).any() or params.shape[0] < min_lineage_rows:
+            logger.info("Without valid histories or the rows of lineages < %s, "
+                        "parameters will be recommended randomly.", min_lineage_rows)
             suggestion = generate_arrays(params_info)
         else:
             self._gp.fit(params, target)
@@ -174,5 +179,5 @@ class GPBaseTuner(BaseTuner):
                 params_info=params_info
             )
 
-        suggestion = Transformer.transform_list_to_dict(params_info, suggestion)
-        return suggestion
+        suggestion, user_defined_info = Transformer.transform_list_to_dict(params_info, suggestion)
+        return suggestion, user_defined_info

mindinsight/optimizer/utils/param_handler.py (+38, -7)

@@ -14,8 +14,9 @@
 # ============================================================================
 """Utils for params."""
 import numpy as np
-from mindinsight.lineagemgr.model import LineageTable
-from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType
+from mindinsight.lineagemgr.model import LineageTable, USER_DEFINED_PREFIX, METRIC_PREFIX
+from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType, HyperParamSource, TargetKey, \
+    TargetGoal, TunableSystemDefinedParams, TargetGroup, SystemDefinedTargets
 from mindinsight.optimizer.common.log import logger
@@ -54,7 +55,6 @@ def match_value_type(array, params_info: dict):
     array_new = []
     index = 0
     for _, param_info in params_info.items():
-        param_type = param_info[HyperParamKey.TYPE.value]
         value = array[index]
         if HyperParamKey.BOUND.value in param_info:
             bound = param_info[HyperParamKey.BOUND.value]
@@ -64,7 +64,7 @@ def match_value_type(array, params_info: dict):
             choices = param_info[HyperParamKey.CHOICE.value]
             nearest_index = int(np.argmin(np.fabs(np.array(choices) - value)))
             value = choices[nearest_index]
-        if param_type == HyperParamType.INT.value:
+        if param_info.get(HyperParamKey.TYPE.value) == HyperParamType.INT.value:
             value = int(value)
         if HyperParamKey.DECIMAL.value in param_info:
             value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
@@ -73,20 +73,51 @@ def match_value_type(array, params_info: dict):
     return array_new
 
 
-def organize_params_target(lineage_table: LineageTable, params_info: dict, target_name):
+def organize_params_target(lineage_table: LineageTable, params_info: dict, target_info):
     """Organize params and target."""
     empty_result = np.array([])
     if lineage_table is None:
         return empty_result, empty_result
 
-    param_keys = list(params_info.keys())
+    param_keys = []
+    for param_key, param_info in params_info.items():
+        # It will be a user_defined param:
+        # 1. if 'source' is specified as 'user_defined';
+        # 2. if 'source' is not specified and the param is not a system_defined key.
+        source = param_info.get(HyperParamKey.SOURCE.value)
+        prefix = _get_prefix(param_key, source, HyperParamSource.USER_DEFINED.value,
+                             USER_DEFINED_PREFIX, TunableSystemDefinedParams.list_members())
+        param_key = f'{prefix}{param_key}'
+        if prefix == USER_DEFINED_PREFIX:
+            param_info[HyperParamKey.SOURCE.value] = HyperParamSource.USER_DEFINED.value
+        else:
+            param_info[HyperParamKey.SOURCE.value] = HyperParamSource.SYSTEM_DEFINED.value
+
+        param_keys.append(param_key)
+
+    target_name = target_info[TargetKey.NAME.value]
+    group = target_info.get(TargetKey.GROUP.value)
+    prefix = _get_prefix(target_name, group, TargetGroup.METRIC.value,
+                         METRIC_PREFIX, SystemDefinedTargets.list_members())
+    target_name = prefix + target_name
     lineage_df = lineage_table.dataframe_data
     try:
         lineage_df = lineage_df[param_keys + [target_name]]
         lineage_df = lineage_df.dropna(axis=0, how='any')
-        return lineage_df[param_keys], lineage_df[target_name]
+
+        target_column = np.array(lineage_df[target_name])
+        if TargetKey.GOAL.value in target_info and \
+                target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value:
+            target_column = -target_column
+
+        return np.array(lineage_df[param_keys]), target_column
     except KeyError as exc:
         logger.warning("Some keys not exist in specified params or target. It will suggest params randomly. "
                        "Detail: %s.", str(exc))
         return empty_result, empty_result
+
+
+def _get_prefix(name, field, other_defined_field, other_defined_prefix, system_defined_fields):
+    """Get the prefix for a param or target name according to its source/group."""
+    if (field == other_defined_field) or (field is None and name not in system_defined_fields):
+        return other_defined_prefix
+    return ''
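
To illustrate the prefixing rule (USER_DEFINED_PREFIX is "[U]" per lineagemgr.model above; 'momentum' is a hypothetical user-defined parameter):

    system_params = ['batch_size', 'epoch', 'learning_rate']  # TunableSystemDefinedParams members

    # 'momentum' is not system defined and has no explicit source -> user-defined prefix.
    _get_prefix('momentum', None, 'user_defined', '[U]', system_params)        # returns '[U]'
    # 'learning_rate' is system defined -> no prefix.
    _get_prefix('learning_rate', None, 'user_defined', '[U]', system_params)   # returns ''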

mindinsight/optimizer/utils/transformer.py (+8, -2)

@@ -14,6 +14,7 @@
 # ============================================================================
 """Transformer."""
 from mindinsight.optimizer.utils.param_handler import match_value_type
+from mindinsight.optimizer.common.enums import HyperParamSource, HyperParamKey
 
 
 class Transformer:
@@ -23,7 +24,12 @@ class Transformer:
         """Transform from tuner."""
         suggest_list = match_value_type(suggest_list, params_info)
         param_dict = {}
+        user_defined_info = {}
         for index, param_name in enumerate(params_info):
-            param_dict.update({param_name: suggest_list[index]})
+            param_item = {param_name: suggest_list[index]}
+            param_dict.update(param_item)
+            source = params_info.get(param_name).get(HyperParamKey.SOURCE.value)
+            if source is not None and source == HyperParamSource.USER_DEFINED.value:
+                user_defined_info.update(param_item)
 
-        return param_dict
+        return param_dict, user_defined_info
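
A hedged example of the new return value, with one system-defined and one user-defined parameter (names and values are illustrative, and match_value_type may still adjust the values):

    params_info = {
        'learning_rate': {'bounds': [0.001, 0.01], 'type': 'float', 'source': 'system_defined'},
        'momentum': {'bounds': [0.1, 0.9], 'type': 'float', 'source': 'user_defined'},
    }
    param_dict, user_defined_info = Transformer.transform_list_to_dict(params_info, [0.005, 0.5])
    # param_dict == {'learning_rate': 0.005, 'momentum': 0.5}
    # user_defined_info == {'momentum': 0.5}  -> recorded as custom lineage data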

mindinsight/optimizer/utils/utils.py (+15, -0)

@@ -68,3 +68,18 @@ def is_simple_numpy_number(dtype):
         return True
 
     return False
+
+
+def get_nested_message(info: dict, out_err_msg=""):
+    """Get error message from the error dict generated by schema validation."""
+    if not isinstance(info, dict):
+        if isinstance(info, list):
+            info = info[0]
+        return f'Error in {out_err_msg}: {info}'
+    for key in info:
+        if isinstance(key, str) and key != '_schema':
+            if out_err_msg:
+                out_err_msg = f'{out_err_msg}.{key}'
+            else:
+                out_err_msg = key
+        return get_nested_message(info[key], out_err_msg)
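
For example, a nested marshmallow error dict flattens to a single dotted message (the error content is illustrative):

    error = {'parameters': {'learning_rate': {'bounds': ['Length of bounds should be 2.']}}}
    get_nested_message(error)
    # -> 'Error in parameters.learning_rate.bounds: Length of bounds should be 2.'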

mindinsight/utils/constant.py (+2, -0)

@@ -98,3 +98,5 @@ class OptimizerErrors(Enum):
     CORRELATION_NAN = 2
     HYPER_CONFIG_ERROR = 3
     OPTIMIZER_TERMINATE = 4
+    CONFIG_PARAM_ERROR = 5
+    HYPER_CONFIG_ENV_ERROR = 6

requirements.txt (+2, -2)

@@ -10,9 +10,9 @@ marshmallow>=2.19.2
 numpy>=1.17.0
 protobuf>=3.8.0
 psutil>=5.6.1
-pyyaml>=5.3
+pyyaml>=5.3.1
 scipy>=1.3.3
-scikit-learn>=0.23.1
+scikit-learn>=0.21.2
 six>=1.12.0
 Werkzeug>=1.0.0
 pandas>=1.0.4


setup.py (+1, -0)

@@ -208,6 +208,7 @@ if __name__ == '__main__':
             'mindinsight=mindinsight.utils.command:main',
             'mindconverter=mindinsight.mindconverter.cli:cli_entry',
             'mindwizard=mindinsight.wizard.cli:cli_entry',
+            'mindoptimizer=mindinsight.optimizer.cli:cli_entry',
         ],
     },
     python_requires='>=3.7',


tests/st/func/lineagemgr/collection/model/test_model_lineage.py (+3, -3)

@@ -98,14 +98,14 @@ class TestModelLineage(TestCase):
         train_callback.end(RunContext(self.run_context))
 
         LINEAGE_DATA_MANAGER.start_load_data().join()
-        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition)
+        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition)
         assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 10
         run_context = self.run_context
         run_context['epoch_num'] = 14
         train_callback.end(RunContext(run_context))
 
         LINEAGE_DATA_MANAGER.start_load_data().join()
-        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition)
+        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition)
         assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 14
 
     @pytest.mark.scene_eval(3)
@@ -198,7 +198,7 @@ class TestModelLineage(TestCase):
         train_callback.end(RunContext(run_context_customized))
 
         LINEAGE_DATA_MANAGER.start_load_data().join()
-        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition)
+        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition)
         assert res.get('object')[0].get('model_lineage', {}).get('loss_function') \
             == 'SoftmaxCrossEntropyWithLogits'
         assert res.get('object')[0].get('model_lineage', {}).get('network') == 'ResNet'


+ 21
- 9
tests/st/func/lineagemgr/test_model.py View File

@@ -190,7 +190,7 @@ class TestModelApi(TestCase):
search_condition = { search_condition = {
'sorted_name': 'summary_dir' 'sorted_name': 'summary_dir'
} }
res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
@@ -228,7 +228,7 @@ class TestModelApi(TestCase):
], ],
'count': 2 'count': 2
} }
partial_res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
@@ -266,7 +266,7 @@ class TestModelApi(TestCase):
], ],
'count': 2 'count': 2
} }
partial_res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
@@ -295,7 +295,7 @@ class TestModelApi(TestCase):
], ],
'count': 3 'count': 3
} }
partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1)
partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1)
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res1.get('object')): for idx, res_object in enumerate(partial_res1.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
@@ -314,7 +314,7 @@ class TestModelApi(TestCase):
'object': [], 'object': [],
'count': 0 'count': 0
} }
partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2)
partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2)
assert expect_result == partial_res2 assert expect_result == partial_res2


@pytest.mark.level0 @pytest.mark.level0
@@ -335,7 +335,7 @@ class TestModelApi(TestCase):
'eq': self._empty_train_id 'eq': self._empty_train_id
} }
} }
res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
assert expect_result == res assert expect_result == res


@pytest.mark.level0 @pytest.mark.level0
@@ -366,7 +366,7 @@ class TestModelApi(TestCase):
], ],
'count': 1 'count': 1
} }
res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
assert expect_result == res assert expect_result == res


@pytest.mark.level0 @pytest.mark.level0
@@ -386,6 +386,7 @@ class TestModelApi(TestCase):
'The search_condition element summary_dir should be dict.', 'The search_condition element summary_dir should be dict.',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -398,6 +399,7 @@ class TestModelApi(TestCase):
'The sorted_name must exist when sorted_type exists.', 'The sorted_name must exist when sorted_type exists.',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -408,6 +410,7 @@ class TestModelApi(TestCase):
'Invalid search_condition type, it should be dict.', 'Invalid search_condition type, it should be dict.',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -420,6 +423,7 @@ class TestModelApi(TestCase):
'The limit must be int.', 'The limit must be int.',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -440,6 +444,7 @@ class TestModelApi(TestCase):
'The offset must be int.', 'The offset must be int.',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -454,6 +459,7 @@ class TestModelApi(TestCase):
'The search attribute not supported.', 'The search attribute not supported.',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -475,6 +481,7 @@ class TestModelApi(TestCase):
'The sorted_type must be ascending or descending', 'The sorted_type must be ascending or descending',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -489,6 +496,7 @@ class TestModelApi(TestCase):
'The compare condition should be in', 'The compare condition should be in',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -503,6 +511,7 @@ class TestModelApi(TestCase):
'The parameter metric/accuracy is invalid.', 'The parameter metric/accuracy is invalid.',
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -526,7 +535,7 @@ class TestModelApi(TestCase):
'object': [], 'object': [],
'count': 0 'count': 0
} }
partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1)
partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1)
assert expect_result == partial_res1 assert expect_result == partial_res1


# the (offset + 1) * limit > count # the (offset + 1) * limit > count
@@ -542,7 +551,7 @@ class TestModelApi(TestCase):
'object': [], 'object': [],
'count': 2 'count': 2
} }
partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2)
partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2)
assert expect_result == partial_res2 assert expect_result == partial_res2


@pytest.mark.level0 @pytest.mark.level0
@@ -566,6 +575,7 @@ class TestModelApi(TestCase):
f'The parameter {condition_key} is invalid. Its operation should be `eq`, `in` or `not_in`.',
filter_summary_lineage,
LINEAGE_DATA_MANAGER,
None,
search_condition
)


@@ -589,6 +599,7 @@ class TestModelApi(TestCase):
"The parameter lineage_type is invalid. It should be 'dataset' or 'model'.", "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
filter_summary_lineage, filter_summary_lineage,
LINEAGE_DATA_MANAGER, LINEAGE_DATA_MANAGER,
None,
search_condition search_condition
) )


@@ -610,6 +621,7 @@ class TestModelApi(TestCase):
'The sorted_name must be in',
filter_summary_lineage,
LINEAGE_DATA_MANAGER,
None,
search_condition
)

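The extra None threaded through each assertRaises call above implies that filter_summary_lineage now takes a second positional parameter between data_manager and search_condition. A minimal sketch of the assumed signature follows; the parameter name summary_base_dir is inferred from this commit's other changes, not confirmed by the hunks above.

# Hypothetical sketch of the assumed signature; names here are inferred.
def filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None):
    """Filter lineage records from a data manager or from a summary directory."""
    if data_manager is None and summary_base_dir is None:
        raise ValueError('One of [data_manager, summary_base_dir] should be given.')
    search_condition = search_condition if search_condition is not None else {}
    # ... validation and querying elided in this sketch ...
    return {'object': [], 'count': 0}

# Positional call sites must now pass None for the unused middle parameter:
#     filter_summary_lineage(LINEAGE_DATA_MANAGER, None, search_condition)
# Keyword call sites need no change:
#     filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)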



+ 1
- 1
tests/ut/lineagemgr/common/validator/test_validate.py View File

@@ -23,7 +23,7 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError




class TestValidateSearchModelCondition(TestCase):
"""Test the mothod of validate_search_model_condition."""
"""Test the method of validate_search_model_condition."""
def test_validate_search_model_condition_param_type_error(self):
"""Test the method of validate_search_model_condition with LineageParamTypeError."""
condition = {


+ 15
- 0
tests/ut/optimizer/common/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test common module."""

+ 15
- 0
tests/ut/optimizer/common/validator/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test validator."""

+ 161
- 0
tests/ut/optimizer/common/validator/test_optimizer_config.py View File

@@ -0,0 +1,161 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test optimizer config schema."""
from copy import deepcopy

from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig
from mindinsight.optimizer.common.enums import TargetGroup, HyperParamSource

_BASE_DATA = {
'command': 'python ./test.py',
'summary_base_dir': './demo_lineage_r0.3',
'tuner': {
'name': 'gp',
'args': {
'method': 'ucb'
}
},
'target': {
'group': 'metric',
'name': 'Accuracy',
'goal': 'maximize'
},
'parameters': {
'learning_rate': {
'bounds': [0.0001, 0.001],
'type': 'float'
},
'batch_size': {
'choice': [32, 64, 128, 256],
'type': 'int'
},
'decay_step': {
'choice': [20],
'type': 'int'
}
}
}


class TestOptimizerConfig:
"""Test the method of validate_search_model_condition."""
_config_dict = dict(_BASE_DATA)

def test_config_dict_with_wrong_type(self):
"""Test config dict with wrong type."""
config_dict = deepcopy(self._config_dict)
init_list = ['a']
init_str = 'a'

config_dict['command'] = init_list
config_dict['summary_base_dir'] = init_list
config_dict['target']['name'] = init_list
config_dict['target']['goal'] = init_list
config_dict['parameters']['learning_rate']['bounds'] = init_str
config_dict['parameters']['learning_rate']['choice'] = init_str
expected_err = {
'command': ["Value type should be 'str'."],
'parameters': {
'learning_rate': {
'bounds': ["Value type should be 'list'."],
'choice': ["Value type should be 'list'."]
}
},
'summary_base_dir': ["Value type should be 'str'."],
'target': {
'name': ["Value type should be 'str'."],
'goal': ["Value type should be 'str'."]
}
}
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

def test_config_dict_with_wrong_value(self):
"""Test config dict with wrong value."""
config_dict = deepcopy(self._config_dict)
init_list = ['a']
init_str = 'a'

config_dict['target']['group'] = init_str
config_dict['target']['goal'] = init_str
config_dict['tuner']['name'] = init_str
config_dict['parameters']['learning_rate']['bounds'] = init_list
config_dict['parameters']['learning_rate']['choice'] = init_list
config_dict['parameters']['learning_rate']['type'] = init_str
expected_err = {
'parameters': {
'learning_rate': {
'bounds': {
0: ['Value(s) should be integer or float.']
},
'choice': {
0: ['Value(s) should be integer or float.']
},
'type': ["The type should be in ['int', 'float']."]
}
},
'target': {
'goal': ["Value should be in ['maximize', 'minimize']. Current value is a."],
'group': ["Value should be in ['system_defined', 'metric']. Current value is a."]},
'tuner': {
'name': ['Must be one of: gp.']
}
}
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

def test_target_combination(self):
"""Test target combination."""
config_dict = deepcopy(self._config_dict)

config_dict['target']['group'] = TargetGroup.SYSTEM_DEFINED.value
config_dict['target']['name'] = 'a'
expected_err = {
'target': {
'group': 'This target is not system defined. Current group is: system_defined.'
}
}
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

def test_parameters_combination1(self):
"""Test parameters combination."""
config_dict = deepcopy(self._config_dict)

config_dict['parameters']['decay_step']['source'] = HyperParamSource.SYSTEM_DEFINED.value
expected_err = {
'parameters': {
'decay_step': {
'source': 'This param is not system defined. Current source is: system_defined.'
}
}
}
err = OptimizerConfig().validate(config_dict)
assert expected_err == err

def test_parameters_combination2(self):
"""Test parameters combination."""
config_dict = deepcopy(self._config_dict)

config_dict['parameters']['decay_step']['bounds'] = [1, 40]
expected_err = {
'parameters': {
'decay_step': {
'_schema': ["Only one of ['bounds', 'choice'] should be specified."]
}
}
}
err = OptimizerConfig().validate(config_dict)
assert expected_err == err
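
The error dictionaries asserted above follow Marshmallow conventions, under which Schema.validate() returns an empty mapping when every field passes. A complementary happy-path test, sketched under that assumption (not part of the commit):

def test_config_dict_valid(self):
    """A minimal sketch: a well-formed config dict should yield no errors."""
    config_dict = deepcopy(self._config_dict)
    err = OptimizerConfig().validate(config_dict)
    # Marshmallow-style validate() returns {} when the document is valid.
    assert err == {}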

+ 15
- 0
tests/ut/optimizer/utils/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test utils."""

tests/ut/optimizer/test_utils.py → tests/ut/optimizer/utils/test_utils.py View File

