From 0f37c8135e27cd5b5d5ca09cb0c9af7f615f03ba Mon Sep 17 00:00:00 2001 From: luopengting Date: Wed, 21 Oct 2020 17:38:33 +0800 Subject: [PATCH] add cli, validation and st: 1. add cli entry 2. add more validation for hyper config 3. add validator for config dict 4. add st for optimizer config validator 5. add custom lineage for hyper config --- .../backend/optimizer/optimizer_api.py | 2 +- mindinsight/lineagemgr/lineage_parser.py | 39 +++- mindinsight/lineagemgr/model.py | 24 ++- mindinsight/optimizer/cli.py | 101 +++++++++ mindinsight/optimizer/common/constants.py | 1 + mindinsight/optimizer/common/enums.py | 27 +++ mindinsight/optimizer/common/exceptions.py | 16 ++ .../optimizer/common/validator/__init__.py | 15 ++ .../common/validator/optimizer_config.py | 202 ++++++++++++++++++ mindinsight/optimizer/hyper_config.py | 94 +++++--- mindinsight/optimizer/tuner.py | 71 +++--- mindinsight/optimizer/tuners/gp_tuner.py | 15 +- mindinsight/optimizer/utils/param_handler.py | 45 +++- mindinsight/optimizer/utils/transformer.py | 10 +- mindinsight/optimizer/utils/utils.py | 15 ++ mindinsight/utils/constant.py | 2 + requirements.txt | 4 +- setup.py | 1 + .../collection/model/test_model_lineage.py | 6 +- tests/st/func/lineagemgr/test_model.py | 30 ++- .../common/validator/test_validate.py | 2 +- tests/ut/optimizer/common/__init__.py | 15 ++ .../ut/optimizer/common/validator/__init__.py | 15 ++ .../common/validator/test_optimizer_config.py | 161 ++++++++++++++ tests/ut/optimizer/utils/__init__.py | 15 ++ tests/ut/optimizer/{ => utils}/test_utils.py | 0 26 files changed, 829 insertions(+), 99 deletions(-) create mode 100644 mindinsight/optimizer/cli.py create mode 100644 mindinsight/optimizer/common/validator/__init__.py create mode 100644 mindinsight/optimizer/common/validator/optimizer_config.py create mode 100644 tests/ut/optimizer/common/__init__.py create mode 100644 tests/ut/optimizer/common/validator/__init__.py create mode 100644 tests/ut/optimizer/common/validator/test_optimizer_config.py create mode 100644 tests/ut/optimizer/utils/__init__.py rename tests/ut/optimizer/{ => utils}/test_utils.py (100%) diff --git a/mindinsight/backend/optimizer/optimizer_api.py b/mindinsight/backend/optimizer/optimizer_api.py index fecd13ea..78099ef2 100644 --- a/mindinsight/backend/optimizer/optimizer_api.py +++ b/mindinsight/backend/optimizer/optimizer_api.py @@ -43,7 +43,7 @@ def get_optimize_targets(): def _get_optimize_targets(data_manager, search_condition=None): """Get optimize targets.""" - table = get_lineage_table(data_manager, search_condition) + table = get_lineage_table(data_manager=data_manager, search_condition=search_condition) target_summaries = [] for target in table.target_names: diff --git a/mindinsight/lineagemgr/lineage_parser.py b/mindinsight/lineagemgr/lineage_parser.py index ed895ce5..647a0607 100644 --- a/mindinsight/lineagemgr/lineage_parser.py +++ b/mindinsight/lineagemgr/lineage_parser.py @@ -15,6 +15,7 @@ """This file is used to parse lineage info.""" import os +from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.lineagemgr.common.exceptions.exceptions import LineageSummaryAnalyzeException, \ LineageEventNotExistException, LineageEventFieldNotExistException, LineageFileNotFoundError, \ MindInsightException @@ -177,13 +178,49 @@ class LineageParser: class LineageOrganizer: """Lineage organizer.""" - def __init__(self, data_manager): + def __init__(self, data_manager=None, summary_base_dir=None): self._data_manager = data_manager + 
self._summary_base_dir = summary_base_dir + self._check_params() self._super_lineage_objs = {} self._organize_from_cache() + self._organize_from_disk() + + def _check_params(self): + """Check params.""" + if self._data_manager is not None and self._summary_base_dir is not None: + self._summary_base_dir = None + + def _organize_from_disk(self): + """Organize lineage objs from disk.""" + if self._summary_base_dir is None: + return + summary_watcher = SummaryWatcher() + relative_dirs = summary_watcher.list_summary_directories( + summary_base_dir=self._summary_base_dir + ) + + no_lineage_count = 0 + for item in relative_dirs: + relative_dir = item.get('relative_path') + update_time = item.get('update_time') + abs_summary_dir = os.path.realpath(os.path.join(self._summary_base_dir, relative_dir)) + + try: + lineage_parser = LineageParser(relative_dir, abs_summary_dir, update_time) + super_lineage_obj = lineage_parser.super_lineage_obj + if super_lineage_obj is not None: + self._super_lineage_objs.update({abs_summary_dir: super_lineage_obj}) + except LineageFileNotFoundError: + no_lineage_count += 1 + + if no_lineage_count == len(relative_dirs): + logger.info('There is no summary log file under summary_base_dir.') def _organize_from_cache(self): """Organize lineage objs from cache.""" + if self._data_manager is None: + return brief_cache = self._data_manager.get_brief_cache() cache_items = brief_cache.cache_items for relative_dir, cache_train_job in cache_items.items(): diff --git a/mindinsight/lineagemgr/model.py b/mindinsight/lineagemgr/model.py index 7f34b070..cdacc4d2 100644 --- a/mindinsight/lineagemgr/model.py +++ b/mindinsight/lineagemgr/model.py @@ -21,6 +21,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySumm from mindinsight.lineagemgr.common.log import logger as log from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition, validate_condition +from mindinsight.lineagemgr.common.validator.validate_path import validate_and_normalize_path from mindinsight.lineagemgr.lineage_parser import LineageOrganizer from mindinsight.lineagemgr.querier.querier import Querier from mindinsight.optimizer.common.enums import ReasonCode @@ -33,7 +34,7 @@ USER_DEFINED_PREFIX = "[U]" USER_DEFINED_INFO_LIMIT = 100 -def filter_summary_lineage(data_manager, search_condition=None): +def filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None): """ Filter summary lineage from data_manager or parsing from summaries. @@ -43,8 +44,18 @@ def filter_summary_lineage(data_manager, search_condition=None): Args: data_manager (DataManager): Data manager defined as mindinsight.datavisual.data_transform.data_manager.DataManager + summary_base_dir (str): The summary base directory. It contains summary + directories generated by training. search_condition (dict): The search condition. 
""" + if data_manager is None and summary_base_dir is None: + raise LineageParamTypeError("One of data_manager or summary_base_dir needs to be specified.") + + if data_manager is None: + summary_base_dir = validate_and_normalize_path(summary_base_dir, 'summary_base_dir') + else: + summary_base_dir = data_manager.summary_base_dir + search_condition = {} if search_condition is None else search_condition try: @@ -56,7 +67,7 @@ def filter_summary_lineage(data_manager, search_condition=None): raise LineageSearchConditionParamError(str(error.message)) try: - lineage_objects = LineageOrganizer(data_manager).super_lineage_objs + lineage_objects = LineageOrganizer(data_manager, summary_base_dir).super_lineage_objs result = Querier(lineage_objects).filter_summary_lineage(condition=search_condition) except LineageSummaryParseException: result = {'object': [], 'count': 0} @@ -68,12 +79,13 @@ def filter_summary_lineage(data_manager, search_condition=None): return result -def get_flattened_lineage(data_manager, search_condition=None): +def get_flattened_lineage(data_manager=None, summary_base_dir=None, search_condition=None): """ Get lineage data in a table from data manager. Args: data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading. + summary_base_dir (str): The base directory for train jobs. search_condition (dict): The search condition. Returns: @@ -81,7 +93,7 @@ def get_flattened_lineage(data_manager, search_condition=None): """ flatten_dict, user_count = {'train_id': []}, 0 - lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition).get("object", []) + lineages = filter_summary_lineage(data_manager, summary_base_dir, search_condition).get("object", []) for index, lineage in enumerate(lineages): flatten_dict['train_id'].append(lineage.get("summary_dir")) for key, val in _flatten_lineage(lineage.get('model_lineage', {})): @@ -222,9 +234,9 @@ class LineageTable: return self._drop_columns_info -def get_lineage_table(data_manager, search_condition=None): +def get_lineage_table(data_manager=None, summary_base_dir=None, search_condition=None): """Get lineage table from data_manager.""" - lineage_table = get_flattened_lineage(data_manager, search_condition) + lineage_table = get_flattened_lineage(data_manager, summary_base_dir, search_condition) lineage_table = LineageTable(pd.DataFrame(lineage_table)) return lineage_table diff --git a/mindinsight/optimizer/cli.py b/mindinsight/optimizer/cli.py new file mode 100644 index 00000000..0ad3206d --- /dev/null +++ b/mindinsight/optimizer/cli.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Command module."""
+import argparse
+import os
+import sys
+
+import mindinsight
+from mindinsight.optimizer.tuner import Tuner
+
+
+class ConfigAction(argparse.Action):
+    """Config file path action class definition."""
+
+    def __call__(self, parser_in, namespace, values, option_string=None):
+        """
+        Inherited __call__ method from argparse.Action.
+
+        Args:
+            parser_in (ArgumentParser): Passed-in argument parser.
+            namespace (Namespace): Namespace object to hold arguments.
+            values (object): Argument values with type depending on argument definition.
+            option_string (str): Option string for specific argument name.
+        """
+        config_path = os.path.realpath(values)
+        if not os.path.exists(config_path):
+            parser_in.error(f'{option_string} {config_path} does not exist.')
+
+        setattr(namespace, self.dest, config_path)
+
+
+class IterAction(argparse.Action):
+    """Iteration times action class definition."""
+
+    def __call__(self, parser_in, namespace, values, option_string=None):
+        """
+        Inherited __call__ method from argparse.Action.
+
+        Args:
+            parser_in (ArgumentParser): Passed-in argument parser.
+            namespace (Namespace): Namespace object to hold arguments.
+            values (object): Argument values with type depending on argument definition.
+            option_string (str): Option string for specific argument name.
+        """
+        iter_times = values
+        if iter_times <= 0:
+            parser_in.error(f'{option_string} {iter_times} should be a positive integer.')
+
+        setattr(namespace, self.dest, iter_times)
+
+
+parser = argparse.ArgumentParser(
+    prog='mindoptimizer',
+    description='MindOptimizer CLI entry point (version: {})'.format(mindinsight.__version__))
+
+parser.add_argument(
+    '--version',
+    action='version',
+    version='%(prog)s ({})'.format(mindinsight.__version__))
+
+parser.add_argument(
+    '--config',
+    type=str,
+    action=ConfigAction,
+    required=True,
+    default=os.path.join(os.getcwd(), 'output'),
+    help="Specify path for config file."
+)
+
+parser.add_argument(
+    '--iter',
+    type=int,
+    action=IterAction,
+    default=1,
+    help="Optional, specify run times for the command in config file."
+)
+
+
+def cli_entry():
+    """Cli entry."""
+    argv = sys.argv[1:]
+    if not argv:
+        argv = ['-h']
+        args = parser.parse_args(argv)
+    else:
+        args = parser.parse_args()
+
+    tuner = Tuner(args.config)
+    tuner.optimize(max_expr_times=args.iter)
diff --git a/mindinsight/optimizer/common/constants.py b/mindinsight/optimizer/common/constants.py
index 2241626d..622a0d1c 100644
--- a/mindinsight/optimizer/common/constants.py
+++ b/mindinsight/optimizer/common/constants.py
@@ -14,3 +14,4 @@
 # ============================================================================
 """Common constants for optimizer."""
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
+HYPER_CONFIG_LEN_LIMIT = 100000
diff --git a/mindinsight/optimizer/common/enums.py b/mindinsight/optimizer/common/enums.py
index adb5216a..aea92fda 100644
--- a/mindinsight/optimizer/common/enums.py
+++ b/mindinsight/optimizer/common/enums.py
@@ -48,12 +48,17 @@ class TuneMethod(BaseEnum):
     GP = 'gp'
 
+
+class GPSupportArgs(BaseEnum):
+    METHOD = 'method'
+
+
 class HyperParamKey(BaseEnum):
     """Config keys for hyper parameters."""
     BOUND = 'bounds'
     CHOICE = 'choice'
     DECIMAL = 'decimal'
     TYPE = 'type'
+    SOURCE = 'source'
 
 
 class HyperParamType(BaseEnum):
@@ -73,3 +78,25 @@ class TargetGoal(BaseEnum):
     """Goal for target."""
     MAXIMUM = 'maximize'
     MINIMUM = 'minimize'
+
+
+class HyperParamSource(BaseEnum):
+    SYSTEM_DEFINED = 'system_defined'
+    USER_DEFINED = 'user_defined'
+
+
+class TargetGroup(BaseEnum):
+    SYSTEM_DEFINED = 'system_defined'
+    METRIC = 'metric'
+
+
+class TunableSystemDefinedParams(BaseEnum):
+    """Tunable metadata keys of lineage collection."""
+    BATCH_SIZE = 'batch_size'
+    EPOCH = 'epoch'
+    LEARNING_RATE = 'learning_rate'
+
+
+class SystemDefinedTargets(BaseEnum):
+    """System defined targets."""
+    LOSS = 'loss'
diff --git a/mindinsight/optimizer/common/exceptions.py b/mindinsight/optimizer/common/exceptions.py
index bf3606a7..e75fb830 100644
--- a/mindinsight/optimizer/common/exceptions.py
+++ b/mindinsight/optimizer/common/exceptions.py
@@ -48,3 +48,19 @@ class OptimizerTerminateError(MindInsightException):
         super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE,
                                                       error_msg,
                                                       http_code=400)
+
+
+class ConfigParamError(MindInsightException):
+    """Config parameter error."""
+    def __init__(self, error_msg="Invalid parameter."):
+        super(ConfigParamError, self).__init__(OptimizerErrors.CONFIG_PARAM_ERROR,
+                                               error_msg,
+                                               http_code=400)
+
+
+class HyperConfigEnvError(MindInsightException):
+    """Hyper config environment error."""
+    def __init__(self, error_msg="Hyper config is not correct."):
+        super(HyperConfigEnvError, self).__init__(OptimizerErrors.HYPER_CONFIG_ENV_ERROR,
+                                                  error_msg,
+                                                  http_code=400)
diff --git a/mindinsight/optimizer/common/validator/__init__.py b/mindinsight/optimizer/common/validator/__init__.py
new file mode 100644
index 00000000..ec8ca645
--- /dev/null
+++ b/mindinsight/optimizer/common/validator/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Common validator modules for optimizer.""" diff --git a/mindinsight/optimizer/common/validator/optimizer_config.py b/mindinsight/optimizer/common/validator/optimizer_config.py new file mode 100644 index 00000000..7d092edc --- /dev/null +++ b/mindinsight/optimizer/common/validator/optimizer_config.py @@ -0,0 +1,202 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Validator for optimizer config.""" +import math + +from marshmallow import Schema, fields, ValidationError, validates, validate, validates_schema + +from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \ + HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \ + HyperParamKey, SystemDefinedTargets + +_BOUND_LEN = 2 +_NUMBER_ERR_MSG = "Value(s) should be integer or float." +_TYPE_ERR_MSG = "Value type should be %r." +_VALUE_ERR_MSG = "Value should be in %s. Current value is %s." + + +def _generate_schema_err_msg(err_msg, *args): + """Organize error messages.""" + if args: + err_msg = err_msg % args + return {"invalid": err_msg} + + +def include_integer(low, high): + """Check if the range [low, high) includes integer.""" + def _in_range(num, low, high): + """check if num in [low, high)""" + return low <= num < high + + if _in_range(math.ceil(low), low, high) or _in_range(math.floor(high), low, high): + return True + return False + + +class TunerSchema(Schema): + """Schema for tuner.""" + dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict") + + name = fields.Str(required=True, + validate=validate.OneOf(TuneMethod.list_members()), + error_messages=_generate_schema_err_msg("Value should be in %s." % TuneMethod.list_members())) + args = fields.Dict(error_messages=dict_err_msg) + + @validates("args") + def check_args(self, data): + """Check args for tuner.""" + data_keys = list(data.keys()) + support_args = GPSupportArgs.list_members() + if not set(data_keys).issubset(set(support_args)): + raise ValidationError("Only support setting %s for tuner. " + "Current key(s): %s." % (support_args, data_keys)) + + method = data.get(GPSupportArgs.METHOD.value) + if not isinstance(method, str): + raise ValidationError("The 'method' type should be str.") + if method not in AcquisitionFunctionEnum.list_members(): + raise ValidationError("Supported acquisition function must be one of %s. Current value is %r." 
%
+                                  (AcquisitionFunctionEnum.list_members(), method))
+
+
+class ParameterSchema(Schema):
+    """Schema for parameter."""
+    number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG)
+    list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list")
+    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
+
+    bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
+    choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
+    type = fields.Str(error_messages=str_err_msg)
+    source = fields.Str(error_messages=str_err_msg)
+
+    @validates("bounds")
+    def check_bounds(self, bounds):
+        """Check if bounds are valid."""
+        if len(bounds) != _BOUND_LEN:
+            raise ValidationError("Length of bounds should be %s." % _BOUND_LEN)
+        if bounds[1] <= bounds[0]:
+            raise ValidationError("The upper bound must be greater than lower bound. "
+                                  "The range is [lower_bound, upper_bound).")
+
+    @validates("type")
+    def check_type(self, type_in):
+        """Check if type is valid."""
+        if type_in not in HyperParamType.list_members():
+            raise ValidationError("The type should be in %s." % HyperParamType.list_members())
+
+    @validates("source")
+    def check_source(self, source):
+        """Check if source is valid."""
+        if source not in HyperParamSource.list_members():
+            raise ValidationError(_VALUE_ERR_MSG % (HyperParamSource.list_members(), source))
+
+    @validates_schema
+    def check_combination(self, data, **kwargs):
+        """Check the combination of parameters."""
+        bound_key = HyperParamKey.BOUND.value
+        choice_key = HyperParamKey.CHOICE.value
+        type_key = HyperParamKey.TYPE.value
+
+        # check bound and type
+        bounds = data.get(bound_key)
+        param_type = data.get(type_key)
+        if bounds is not None:
+            if param_type is None:
+                raise ValidationError("If %r is specified, %r should also be specified." %
+                                      (HyperParamKey.BOUND.value, HyperParamKey.TYPE.value))
+            if param_type == HyperParamType.INT.value and not include_integer(bounds[0], bounds[1]):
+                raise ValidationError("No integer in 'bounds', please modify it.")
+
+        # check bound and choice
+        if (bound_key in data and choice_key in data) or (bound_key not in data and choice_key not in data):
+            raise ValidationError("Only one of [%r, %r] should be specified." %
+                                  (bound_key, choice_key))
+
+
+class TargetSchema(Schema):
+    """Schema for target."""
+    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
+
+    group = fields.Str(error_messages=str_err_msg)
+    name = fields.Str(required=True, error_messages=str_err_msg)
+    goal = fields.Str(error_messages=str_err_msg)
+
+    @validates("group")
+    def check_group(self, group):
+        """Check if group is valid."""
+        if group not in TargetGroup.list_members():
+            raise ValidationError(_VALUE_ERR_MSG % (TargetGroup.list_members(), group))
+
+    @validates("goal")
+    def check_goal(self, goal):
+        """Check if goal is valid."""
+        if goal not in TargetGoal.list_members():
+            raise ValidationError(_VALUE_ERR_MSG % (TargetGoal.list_members(), goal))
+
+    @validates_schema
+    def check_combination(self, data, **kwargs):
+        """Check the combination of target fields."""
+        if TargetKey.GROUP.value not in data:
+            # If name is in system defined targets, group will be 'system_defined', else 'metric'.
+            return
+        name = data.get(TargetKey.NAME.value)
+        group = data.get(TargetKey.GROUP.value)
+        if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members():
+            raise ValidationError({
+                TargetKey.GROUP.value: "This target is not system defined. "
+                                       "Current group is: %s." % group})
+
+
+class OptimizerConfig(Schema):
+    """Schema for optimizer config."""
+    dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
+    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")
+
+    summary_base_dir = fields.Str(required=True, error_messages=str_err_msg)
+    command = fields.Str(required=True, error_messages=str_err_msg)
+    tuner = fields.Dict(required=True, error_messages=dict_err_msg)
+    target = fields.Dict(required=True, error_messages=dict_err_msg)
+    parameters = fields.Dict(required=True, error_messages=dict_err_msg)
+
+    @validates("tuner")
+    def check_tuner(self, data):
+        """Check tuner."""
+        err = TunerSchema().validate(data)
+        if err:
+            raise ValidationError(err)
+
+    @validates("parameters")
+    def check_parameters(self, parameters):
+        """Check parameters."""
+        for name, value in parameters.items():
+            err = ParameterSchema().validate(value)
+            if err:
+                raise ValidationError({name: err})
+
+            if HyperParamKey.SOURCE.value not in value:
+                # If the param is in system defined keys, source will be 'system_defined', else 'user_defined'.
+                continue
+            source = value.get(HyperParamKey.SOURCE.value)
+            if source == HyperParamSource.SYSTEM_DEFINED.value and \
+                    name not in TunableSystemDefinedParams.list_members():
+                raise ValidationError({
+                    name: {"source": "This param is not system defined. Current source is: %s." % source}})
+
+    @validates("target")
+    def check_target(self, target):
+        """Check target."""
+        err = TargetSchema().validate(target)
+        if err:
+            raise ValidationError(err)
diff --git a/mindinsight/optimizer/hyper_config.py b/mindinsight/optimizer/hyper_config.py
index 5e0ba3b6..15bd99c7 100644
--- a/mindinsight/optimizer/hyper_config.py
+++ b/mindinsight/optimizer/hyper_config.py
@@ -15,30 +15,73 @@
 """Hyper config."""
 import json
 import os
-from attrdict import AttrDict
-from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
-from mindinsight.optimizer.common.exceptions import HyperConfigError
+from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME, HYPER_CONFIG_LEN_LIMIT
+from mindinsight.optimizer.common.exceptions import HyperConfigEnvError, HyperConfigError
 
-_HYPER_CONFIG_LEN_LIMIT = 100000
 
+class AttributeDict(dict):
+    """A dict that can be accessed by attribute."""
+    def __init__(self, d=None):
+        super().__init__()
+        if d is not None:
+            for k, v in d.items():
+                self[k] = v
-class HyperConfig:
-    """
-    Hyper config.
+    def __key(self, key):
+        """Get key."""
+        return "" if key is None else key
+
+    def __setattr__(self, key, value):
+        """Set attribute for object."""
+        self[self.__key(key)] = value
+
+    def __getattr__(self, key):
+        """
+        Get attribute value by attribute name.
+
+        Args:
+            key (str): Attribute name.
+
+        Returns:
+            Any, attribute value.
-    Init hyper config:
-    >>> hyper_config = HyperConfig()
+        Raises:
+            AttributeError: If the key does not exist.
-    Get suggest params:
-    >>> param_obj = hyper_config.params
-    >>> learning_rate = params.learning_rate
+        """
+        value = self.get(self.__key(key))
+        if value is None:
+            raise AttributeError("The attribute %r does not exist." % key)
+        return value
-    Get summary dir:
-    >>> summary_dir = hyper_config.summary_dir
+    def __getitem__(self, key):
+        """Get attribute value by attribute name."""
+        value = super().get(self.__key(key))
+        if value is None:
+            raise AttributeError("The attribute %r does not exist."
% key) + return value - Record by SummaryCollector: - >>> summary_cb = SummaryCollector(summary_dir) + def __setitem__(self, key, value): + """Set attribute for object.""" + return super().__setitem__(self.__key(key), value) + + +class HyperConfig: + """ + Hyper config. + 1. Init HyperConfig. + 2. Get suggested params and summary_dir. + 3. Record by SummaryCollector with summary_dir. + + Examples: + >>> hyper_config = HyperConfig() + >>> params = hyper_config.params + >>> learning_rate = params.learning_rate + >>> batch_size = params.batch_size + + >>> summary_dir = hyper_config.summary_dir + >>> summary_cb = SummaryCollector(summary_dir) """ def __init__(self): self._init_validate_hyper_config() @@ -47,10 +90,10 @@ class HyperConfig: """Init and validate hyper config.""" hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME) if hyper_config is None: - raise HyperConfigError("Hyper config is not in system environment.") - if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT: - raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of " - "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config))) + raise HyperConfigEnvError("Hyper config is not in system environment.") + if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT: + raise HyperConfigEnvError("Hyper config is too long. The length limit is %s, the length of " + "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config))) try: hyper_config = json.loads(hyper_config) @@ -60,8 +103,7 @@ class HyperConfig: raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc)) self._validate_hyper_config(hyper_config) - self._summary_dir = hyper_config.get('summary_dir') - self._param_obj = AttrDict(hyper_config.get('params')) + self._hyper_config = hyper_config def _validate_hyper_config(self, hyper_config): """Validate hyper config.""" @@ -86,9 +128,13 @@ class HyperConfig: @property def params(self): """Get params.""" - return self._param_obj + return AttributeDict(self._hyper_config.get('params')) @property def summary_dir(self): """Get train summary dir path.""" - return self._summary_dir + return self._hyper_config.get('summary_dir') + + @property + def custom_lineage_data(self): + return self._hyper_config.get('custom_lineage_data') diff --git a/mindinsight/optimizer/tuner.py b/mindinsight/optimizer/tuner.py index 116a501a..982f1dfc 100644 --- a/mindinsight/optimizer/tuner.py +++ b/mindinsight/optimizer/tuner.py @@ -22,16 +22,16 @@ import yaml from marshmallow import ValidationError -from mindinsight.datavisual.data_transform.data_manager import DataManager -from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path -from mindinsight.lineagemgr.model import get_lineage_table, LineageTable, METRIC_PREFIX +from mindinsight.lineagemgr.model import get_lineage_table, LineageTable from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME -from mindinsight.optimizer.common.enums import TuneMethod, TargetKey, TargetGoal -from mindinsight.optimizer.common.exceptions import OptimizerTerminateError +from mindinsight.optimizer.common.enums import TuneMethod +from mindinsight.optimizer.common.exceptions import OptimizerTerminateError, ConfigParamError from mindinsight.optimizer.common.log import logger +from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner from 
mindinsight.optimizer.utils.param_handler import organize_params_target +from mindinsight.optimizer.utils.utils import get_nested_message from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError _OK = 0 @@ -51,7 +51,6 @@ class Tuner: def __init__(self, config_path: str): self._config_info = self._validate_config(config_path) self._summary_base_dir = self._config_info.get('summary_base_dir') - self._data_manager = self._init_data_manager() self._dir_prefix = 'train' def _validate_config(self, config_path): @@ -65,11 +64,17 @@ class Tuner: except Exception as exc: raise UnknownError("Detail: %s." % str(exc)) - # need to add validation for config_info: command, summary_base_dir, target and params. + self._validate_config_schema(config_info) config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir')) self._make_summary_base_dir(config_info['summary_base_dir']) return config_info + def _validate_config_schema(self, config_info): + error = OptimizerConfig().validate(config_info) + if error: + err_msg = get_nested_message(error) + raise ConfigParamError(err_msg) + def _make_summary_base_dir(self, summary_base_dir): """Check and make summary_base_dir.""" if not os.path.exists(summary_base_dir): @@ -82,13 +87,6 @@ class Tuner: except OSError as exc: raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc)) - def _init_data_manager(self): - """Initialize data_manager.""" - data_manager = DataManager(summary_base_dir=self._summary_base_dir) - data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater()) - - return data_manager - def _normalize_path(self, param_name, path): """Normalize config path.""" path = os.path.realpath(path) @@ -104,10 +102,8 @@ class Tuner: def _update_from_lineage(self): """Update lineage from lineagemgr.""" - self._data_manager.start_load_data(reload_interval=0).join() - try: - lineage_table = get_lineage_table(self._data_manager) + lineage_table = get_lineage_table(summary_base_dir=self._summary_base_dir) except MindInsightException as err: logger.info("Can not query lineage. 
Detail: %s", str(err)) lineage_table = None @@ -122,12 +118,14 @@ class Tuner: tuner = self._config_info.get('tuner') for _ in range(max_expr_times): self._update_from_lineage() - suggestion = self._suggest(self._lineage_table, params_info, target_info, method=tuner.get("name")) + suggestion, user_defined_info = self._suggest(self._lineage_table, params_info, target_info, tuner) hyper_config = { 'params': suggestion, - 'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}') + 'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}'), + 'custom_lineage_data': user_defined_info } + logger.info("Suggest values are: %s.", suggestion) os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config) s = subprocess.Popen(shlex.split(command)) s.wait() @@ -135,29 +133,26 @@ class Tuner: logger.error("An error occurred during execution, the auto tuning will be terminated.") raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.") - def _get_tuner(self, tune_method=TuneMethod.GP.value): + def _get_tuner(self, tuner): """Get tuner.""" - if tune_method.lower() not in TuneMethod.list_members(): + if tuner is None: + return GPBaseTuner() + + tuner_name = tuner.get("name").lower() + if tuner_name not in TuneMethod.list_members(): raise ParamValueError("'tune_method' should in %s." % TuneMethod.list_members()) + args = tuner.get("args") + if args is not None and args.get("method") is not None: + return GPBaseTuner(args.get("method")) + # Only support gaussian process regressor currently. return GPBaseTuner() - def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, method): + def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, tuner): """Get suggestions for targets.""" - tuner = self._get_tuner(method) - target_name = target_info[TargetKey.NAME.value] - if TargetKey.GROUP.value in target_info and target_info[TargetKey.GROUP.value] == 'metric': - target_name = METRIC_PREFIX + target_name - param_matrix, target_matrix = organize_params_target(lineage_table, params_info, target_name) - - if not param_matrix.empty: - suggestion = tuner.suggest([], [], params_info) - else: - target_column = target_matrix[target_name].reshape((-1, 1)) - if target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value: - target_column = -target_column - - suggestion = tuner.suggest(param_matrix, target_column, params_info) - - return suggestion + tuner = self._get_tuner(tuner) + param_matrix, target_column = organize_params_target(lineage_table, params_info, target_info) + suggestion, user_defined_info = tuner.suggest(param_matrix, target_column, params_info) + + return suggestion, user_defined_info diff --git a/mindinsight/optimizer/tuners/gp_tuner.py b/mindinsight/optimizer/tuners/gp_tuner.py index adbf8976..ff2b69e1 100644 --- a/mindinsight/optimizer/tuners/gp_tuner.py +++ b/mindinsight/optimizer/tuners/gp_tuner.py @@ -22,10 +22,11 @@ from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import Matern from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey +from mindinsight.optimizer.common.log import logger +from mindinsight.optimizer.tuners.base_tuner import BaseTuner from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type from mindinsight.optimizer.utils.transformer import Transformer from mindinsight.utils.exceptions import 
ParamValueError -from mindinsight.optimizer.tuners.base_tuner import BaseTuner class AcquisitionFunction: @@ -141,8 +142,10 @@ class GPBaseTuner(BaseTuner): x_seeds = generate_arrays(params_info, n_iter) for x_try in x_seeds: - res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max), - x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max), + x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B") if not res.success: continue @@ -164,6 +167,8 @@ class GPBaseTuner(BaseTuner): min_lineage_rows = 2 if not np.array(params).any() or params.shape[0] < min_lineage_rows: + logger.info("Without valid histories or the rows of lineages < %s, " + "parameters will be recommended randomly.", min_lineage_rows) suggestion = generate_arrays(params_info) else: self._gp.fit(params, target) @@ -174,5 +179,5 @@ class GPBaseTuner(BaseTuner): params_info=params_info ) - suggestion = Transformer.transform_list_to_dict(params_info, suggestion) - return suggestion + suggestion, user_defined_info = Transformer.transform_list_to_dict(params_info, suggestion) + return suggestion, user_defined_info diff --git a/mindinsight/optimizer/utils/param_handler.py b/mindinsight/optimizer/utils/param_handler.py index 65736654..caca6224 100644 --- a/mindinsight/optimizer/utils/param_handler.py +++ b/mindinsight/optimizer/utils/param_handler.py @@ -14,8 +14,9 @@ # ============================================================================ """Utils for params.""" import numpy as np -from mindinsight.lineagemgr.model import LineageTable -from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType +from mindinsight.lineagemgr.model import LineageTable, USER_DEFINED_PREFIX, METRIC_PREFIX +from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType, HyperParamSource, TargetKey, \ + TargetGoal, TunableSystemDefinedParams, TargetGroup, SystemDefinedTargets from mindinsight.optimizer.common.log import logger @@ -54,7 +55,6 @@ def match_value_type(array, params_info: dict): array_new = [] index = 0 for _, param_info in params_info.items(): - param_type = param_info[HyperParamKey.TYPE.value] value = array[index] if HyperParamKey.BOUND.value in param_info: bound = param_info[HyperParamKey.BOUND.value] @@ -64,7 +64,7 @@ def match_value_type(array, params_info: dict): choices = param_info[HyperParamKey.CHOICE.value] nearest_index = int(np.argmin(np.fabs(np.array(choices) - value))) value = choices[nearest_index] - if param_type == HyperParamType.INT.value: + if param_info.get(HyperParamKey.TYPE.value) == HyperParamType.INT.value: value = int(value) if HyperParamKey.DECIMAL.value in param_info: value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value]) @@ -73,20 +73,51 @@ def match_value_type(array, params_info: dict): return array_new -def organize_params_target(lineage_table: LineageTable, params_info: dict, target_name): +def organize_params_target(lineage_table: LineageTable, params_info: dict, target_info): """Organize params and target.""" empty_result = np.array([]) if lineage_table is None: return empty_result, empty_result - param_keys = list(params_info.keys()) + param_keys = [] + for param_key, param_info in params_info.items(): + # It will be a user_defined param: + # 1. if 'source' is specified as 'user_defined' + # 2. 
if 'source' is not specified and the param is not a system_defined key
+        source = param_info.get(HyperParamKey.SOURCE.value)
+        prefix = _get_prefix(param_key, source, HyperParamSource.USER_DEFINED.value,
+                             USER_DEFINED_PREFIX, TunableSystemDefinedParams.list_members())
+        param_key = f'{prefix}{param_key}'
+        if prefix == USER_DEFINED_PREFIX:
+            param_info[HyperParamKey.SOURCE.value] = HyperParamSource.USER_DEFINED.value
+        else:
+            param_info[HyperParamKey.SOURCE.value] = HyperParamSource.SYSTEM_DEFINED.value
+
+        param_keys.append(param_key)
 
+    target_name = target_info[TargetKey.NAME.value]
+    group = target_info.get(TargetKey.GROUP.value)
+    prefix = _get_prefix(target_name, group, TargetGroup.METRIC.value,
+                         METRIC_PREFIX, SystemDefinedTargets.list_members())
+    target_name = prefix + target_name
     lineage_df = lineage_table.dataframe_data
     try:
         lineage_df = lineage_df[param_keys + [target_name]]
         lineage_df = lineage_df.dropna(axis=0, how='any')
-        return lineage_df[param_keys], lineage_df[target_name]
+
+        target_column = np.array(lineage_df[target_name])
+        if TargetKey.GOAL.value in target_info and \
+                target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value:
+            target_column = -target_column
+
+        return np.array(lineage_df[param_keys]), target_column
     except KeyError as exc:
-        logger.warning("Some keys not exist in specified params or target. It will suggest params randomly."
+        logger.warning("Some keys do not exist in specified params or target. It will suggest params randomly. "
                        "Detail: %s.", str(exc))
         return empty_result, empty_result
+
+
+def _get_prefix(name, field, other_defined_field, other_defined_prefix, system_defined_fields):
+    """Get the prefix for a name according to its source/group field."""
+    if (field == other_defined_field) or (field is None and name not in system_defined_fields):
+        return other_defined_prefix
+    return ''
diff --git a/mindinsight/optimizer/utils/transformer.py b/mindinsight/optimizer/utils/transformer.py
index f15f68b4..8ad8fd9c 100644
--- a/mindinsight/optimizer/utils/transformer.py
+++ b/mindinsight/optimizer/utils/transformer.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """Transformer."""
 from mindinsight.optimizer.utils.param_handler import match_value_type
+from mindinsight.optimizer.common.enums import HyperParamSource, HyperParamKey
 
 
 class Transformer:
@@ -23,7 +24,12 @@ class Transformer:
         """Transform from tuner."""
         suggest_list = match_value_type(suggest_list, params_info)
         param_dict = {}
+        user_defined_info = {}
         for index, param_name in enumerate(params_info):
-            param_dict.update({param_name: suggest_list[index]})
+            param_item = {param_name: suggest_list[index]}
+            param_dict.update(param_item)
+            source = params_info.get(param_name).get(HyperParamKey.SOURCE.value)
+            if source is not None and source == HyperParamSource.USER_DEFINED.value:
+                user_defined_info.update(param_item)
 
-        return param_dict
+        return param_dict, user_defined_info
diff --git a/mindinsight/optimizer/utils/utils.py b/mindinsight/optimizer/utils/utils.py
index d0ba698d..a9cacbb2 100644
--- a/mindinsight/optimizer/utils/utils.py
+++ b/mindinsight/optimizer/utils/utils.py
@@ -68,3 +68,18 @@ def is_simple_numpy_number(dtype):
         return True
 
     return False
+
+
+def get_nested_message(info: dict, out_err_msg=""):
+    """Get error message from the error dict generated by schema validation."""
+    if not isinstance(info, dict):
+        if isinstance(info, list):
+            info = info[0]
+        return f'Error in {out_err_msg}: {info}'
+    for key in info:
+        if isinstance(key, str) and key != '_schema':
+            if out_err_msg:
+                out_err_msg = f'{out_err_msg}.{key}'
+            else:
+                out_err_msg = key
+        return get_nested_message(info[key], out_err_msg)
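(Illustration only, not part of the patch.) The new `Tuner._validate_config_schema` passes the nested error dict returned by `OptimizerConfig().validate(...)` through `get_nested_message` above to build a single `ConfigParamError` message. A minimal sketch of that flattening, using the module paths added in this patch:

```python
# Illustration only: flattening a marshmallow-style error dict with
# get_nested_message, as Tuner._validate_config_schema does.
from mindinsight.optimizer.utils.utils import get_nested_message

# The shape below mirrors what OptimizerConfig().validate(...) returns
# for a 'bounds' value of the wrong type (see the unit tests further down).
errors = {
    'parameters': {
        'learning_rate': {
            'bounds': ["Value type should be 'list'."]
        }
    }
}

print(get_nested_message(errors))
# -> Error in parameters.learning_rate.bounds: Value type should be 'list'.
```

The recursion walks one key per level, joining keys with dots and skipping the `_schema` marker, so the deepest message surfaces with its full path.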
diff --git a/mindinsight/utils/constant.py b/mindinsight/utils/constant.py index 480cde9f..00622187 100644 --- a/mindinsight/utils/constant.py +++ b/mindinsight/utils/constant.py @@ -98,3 +98,5 @@ class OptimizerErrors(Enum): CORRELATION_NAN = 2 HYPER_CONFIG_ERROR = 3 OPTIMIZER_TERMINATE = 4 + CONFIG_PARAM_ERROR = 5 + HYPER_CONFIG_ENV_ERROR = 6 diff --git a/requirements.txt b/requirements.txt index 82bd8894..7af7a8b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,9 +10,9 @@ marshmallow>=2.19.2 numpy>=1.17.0 protobuf>=3.8.0 psutil>=5.6.1 -pyyaml>=5.3 +pyyaml>=5.3.1 scipy>=1.3.3 -scikit-learn>=0.23.1 +scikit-learn>=0.21.2 six>=1.12.0 Werkzeug>=1.0.0 pandas>=1.0.4 diff --git a/setup.py b/setup.py index dc433dca..4b0ccd81 100644 --- a/setup.py +++ b/setup.py @@ -208,6 +208,7 @@ if __name__ == '__main__': 'mindinsight=mindinsight.utils.command:main', 'mindconverter=mindinsight.mindconverter.cli:cli_entry', 'mindwizard=mindinsight.wizard.cli:cli_entry', + 'mindoptimizer=mindinsight.optimizer.cli:cli_entry', ], }, python_requires='>=3.7', diff --git a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py index 93338745..bafc2484 100644 --- a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py @@ -98,14 +98,14 @@ class TestModelLineage(TestCase): train_callback.end(RunContext(self.run_context)) LINEAGE_DATA_MANAGER.start_load_data().join() - res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition) + res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition) assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 10 run_context = self.run_context run_context['epoch_num'] = 14 train_callback.end(RunContext(run_context)) LINEAGE_DATA_MANAGER.start_load_data().join() - res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition) + res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition) assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 14 @pytest.mark.scene_eval(3) @@ -198,7 +198,7 @@ class TestModelLineage(TestCase): train_callback.end(RunContext(run_context_customized)) LINEAGE_DATA_MANAGER.start_load_data().join() - res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition) + res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition) assert res.get('object')[0].get('model_lineage', {}).get('loss_function') \ == 'SoftmaxCrossEntropyWithLogits' assert res.get('object')[0].get('model_lineage', {}).get('network') == 'ResNet' diff --git a/tests/st/func/lineagemgr/test_model.py b/tests/st/func/lineagemgr/test_model.py index 9e863ca2..89a0eb2c 100644 --- a/tests/st/func/lineagemgr/test_model.py +++ b/tests/st/func/lineagemgr/test_model.py @@ -190,7 +190,7 @@ class TestModelApi(TestCase): search_condition = { 'sorted_name': 'summary_dir' } - res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) + res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) expect_objects = expect_result.get('object') for idx, res_object in enumerate(res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') @@ -228,7 +228,7 @@ class TestModelApi(TestCase): ], 'count': 2 } - partial_res = 
filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) + partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) expect_objects = expect_result.get('object') for idx, res_object in enumerate(partial_res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') @@ -266,7 +266,7 @@ class TestModelApi(TestCase): ], 'count': 2 } - partial_res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) + partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) expect_objects = expect_result.get('object') for idx, res_object in enumerate(partial_res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') @@ -295,7 +295,7 @@ class TestModelApi(TestCase): ], 'count': 3 } - partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1) + partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1) expect_objects = expect_result.get('object') for idx, res_object in enumerate(partial_res1.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') @@ -314,7 +314,7 @@ class TestModelApi(TestCase): 'object': [], 'count': 0 } - partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2) + partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2) assert expect_result == partial_res2 @pytest.mark.level0 @@ -335,7 +335,7 @@ class TestModelApi(TestCase): 'eq': self._empty_train_id } } - res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) + res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) assert expect_result == res @pytest.mark.level0 @@ -366,7 +366,7 @@ class TestModelApi(TestCase): ], 'count': 1 } - res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) + res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) assert expect_result == res @pytest.mark.level0 @@ -386,6 +386,7 @@ class TestModelApi(TestCase): 'The search_condition element summary_dir should be dict.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -398,6 +399,7 @@ class TestModelApi(TestCase): 'The sorted_name must exist when sorted_type exists.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -408,6 +410,7 @@ class TestModelApi(TestCase): 'Invalid search_condition type, it should be dict.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -420,6 +423,7 @@ class TestModelApi(TestCase): 'The limit must be int.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -440,6 +444,7 @@ class TestModelApi(TestCase): 'The offset must be int.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -454,6 +459,7 @@ class TestModelApi(TestCase): 'The search attribute not supported.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -475,6 +481,7 @@ class TestModelApi(TestCase): 'The sorted_type must be ascending or descending', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -489,6 +496,7 @@ class TestModelApi(TestCase): 'The compare condition should be in', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, 
search_condition ) @@ -503,6 +511,7 @@ class TestModelApi(TestCase): 'The parameter metric/accuracy is invalid.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -526,7 +535,7 @@ class TestModelApi(TestCase): 'object': [], 'count': 0 } - partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1) + partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1) assert expect_result == partial_res1 # the (offset + 1) * limit > count @@ -542,7 +551,7 @@ class TestModelApi(TestCase): 'object': [], 'count': 2 } - partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2) + partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2) assert expect_result == partial_res2 @pytest.mark.level0 @@ -566,6 +575,7 @@ class TestModelApi(TestCase): f'The parameter {condition_key} is invalid. Its operation should be `eq`, `in` or `not_in`.', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -589,6 +599,7 @@ class TestModelApi(TestCase): "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.", filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) @@ -610,6 +621,7 @@ class TestModelApi(TestCase): 'The sorted_name must be in', filter_summary_lineage, LINEAGE_DATA_MANAGER, + None, search_condition ) diff --git a/tests/ut/lineagemgr/common/validator/test_validate.py b/tests/ut/lineagemgr/common/validator/test_validate.py index f6cdab9e..31adea03 100644 --- a/tests/ut/lineagemgr/common/validator/test_validate.py +++ b/tests/ut/lineagemgr/common/validator/test_validate.py @@ -23,7 +23,7 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError class TestValidateSearchModelCondition(TestCase): - """Test the mothod of validate_search_model_condition.""" + """Test the method of validate_search_model_condition.""" def test_validate_search_model_condition_param_type_error(self): """Test the method of validate_search_model_condition with LineageParamTypeError.""" condition = { diff --git a/tests/ut/optimizer/common/__init__.py b/tests/ut/optimizer/common/__init__.py new file mode 100644 index 00000000..f3402518 --- /dev/null +++ b/tests/ut/optimizer/common/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test common module.""" diff --git a/tests/ut/optimizer/common/validator/__init__.py b/tests/ut/optimizer/common/validator/__init__.py new file mode 100644 index 00000000..461c7904 --- /dev/null +++ b/tests/ut/optimizer/common/validator/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test validator."""
diff --git a/tests/ut/optimizer/common/validator/test_optimizer_config.py b/tests/ut/optimizer/common/validator/test_optimizer_config.py
new file mode 100644
index 00000000..7be0379a
--- /dev/null
+++ b/tests/ut/optimizer/common/validator/test_optimizer_config.py
@@ -0,0 +1,161 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test optimizer config schema."""
+from copy import deepcopy
+
+from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig
+from mindinsight.optimizer.common.enums import TargetGroup, HyperParamSource
+
+_BASE_DATA = {
+    'command': 'python ./test.py',
+    'summary_base_dir': './demo_lineage_r0.3',
+    'tuner': {
+        'name': 'gp',
+        'args': {
+            'method': 'ucb'
+        }
+    },
+    'target': {
+        'group': 'metric',
+        'name': 'Accuracy',
+        'goal': 'maximize'
+    },
+    'parameters': {
+        'learning_rate': {
+            'bounds': [0.0001, 0.001],
+            'type': 'float'
+        },
+        'batch_size': {
+            'choice': [32, 64, 128, 256],
+            'type': 'int'
+        },
+        'decay_step': {
+            'choice': [20],
+            'type': 'int'
+        }
+    }
+}
+
+
+class TestOptimizerConfig:
+    """Test the OptimizerConfig schema."""
+    _config_dict = dict(_BASE_DATA)
+
+    def test_config_dict_with_wrong_type(self):
+        """Test config dict with wrong type."""
+        config_dict = deepcopy(self._config_dict)
+        init_list = ['a']
+        init_str = 'a'
+
+        config_dict['command'] = init_list
+        config_dict['summary_base_dir'] = init_list
+        config_dict['target']['name'] = init_list
+        config_dict['target']['goal'] = init_list
+        config_dict['parameters']['learning_rate']['bounds'] = init_str
+        config_dict['parameters']['learning_rate']['choice'] = init_str
+        expected_err = {
+            'command': ["Value type should be 'str'."],
+            'parameters': {
+                'learning_rate': {
+                    'bounds': ["Value type should be 'list'."],
+                    'choice': ["Value type should be 'list'."]
+                }
+            },
+            'summary_base_dir': ["Value type should be 'str'."],
+            'target': {
+                'name': ["Value type should be 'str'."],
+                'goal': ["Value type should be 'str'."]
+            }
+        }
+        err = OptimizerConfig().validate(config_dict)
+        assert expected_err == err
+
+    def test_config_dict_with_wrong_value(self):
+        """Test config dict with wrong value."""
+        config_dict = deepcopy(self._config_dict)
+        init_list = ['a']
+        init_str = 'a'
+
+        config_dict['target']['group'] = init_str
+        config_dict['target']['goal'] = init_str
+        config_dict['tuner']['name'] = init_str
+
config_dict['parameters']['learning_rate']['bounds'] = init_list + config_dict['parameters']['learning_rate']['choice'] = init_list + config_dict['parameters']['learning_rate']['type'] = init_str + expected_err = { + 'parameters': { + 'learning_rate': { + 'bounds': { + 0: ['Value(s) should be integer or float.'] + }, + 'choice': { + 0: ['Value(s) should be integer or float.'] + }, + 'type': ["The type should be in ['int', 'float']."] + } + }, + 'target': { + 'goal': ["Value should be in ['maximize', 'minimize']. Current value is a."], + 'group': ["Value should be in ['system_defined', 'metric']. Current value is a."]}, + 'tuner': { + 'name': ['Must be one of: gp.'] + } + } + err = OptimizerConfig().validate(config_dict) + assert expected_err == err + + def test_target_combination(self): + """Test target combination.""" + config_dict = deepcopy(self._config_dict) + + config_dict['target']['group'] = TargetGroup.SYSTEM_DEFINED.value + config_dict['target']['name'] = 'a' + expected_err = { + 'target': { + 'group': 'This target is not system defined. Current group is: system_defined.' + } + } + err = OptimizerConfig().validate(config_dict) + assert expected_err == err + + def test_parameters_combination1(self): + """Test parameters combination.""" + config_dict = deepcopy(self._config_dict) + + config_dict['parameters']['decay_step']['source'] = HyperParamSource.SYSTEM_DEFINED.value + expected_err = { + 'parameters': { + 'decay_step': { + 'source': 'This param is not system defined. Current source is: system_defined.' + } + } + } + err = OptimizerConfig().validate(config_dict) + assert expected_err == err + + def test_parameters_combination2(self): + """Test parameters combination.""" + config_dict = deepcopy(self._config_dict) + + config_dict['parameters']['decay_step']['bounds'] = [1, 40] + expected_err = { + 'parameters': { + 'decay_step': { + '_schema': ["Only one of ['bounds', 'choice'] should be specified."] + } + } + } + err = OptimizerConfig().validate(config_dict) + assert expected_err == err diff --git a/tests/ut/optimizer/utils/__init__.py b/tests/ut/optimizer/utils/__init__.py new file mode 100644 index 00000000..95ed9df4 --- /dev/null +++ b/tests/ut/optimizer/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test utils.""" diff --git a/tests/ut/optimizer/test_utils.py b/tests/ut/optimizer/utils/test_utils.py similarity index 100% rename from tests/ut/optimizer/test_utils.py rename to tests/ut/optimizer/utils/test_utils.py
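(Illustration only, not part of the patch.) End to end, `mindoptimizer --config <file> --iter N` validates the config, suggests parameter values, and launches the configured `command` with `MINDINSIGHT_HYPER_CONFIG` set in the environment; the training script then consumes those values through `HyperConfig`. A minimal sketch of such a script, assuming MindSpore's `SummaryCollector` accepts `custom_lineage_data` as in MindSpore 1.x (the model and dataset setup is omitted):

```python
# Illustration only: a training script launched via the 'command' field in
# the mindoptimizer config file. Model/dataset construction is omitted.
from mindspore.train.callback import SummaryCollector

from mindinsight.optimizer.hyper_config import HyperConfig

# Reads and validates MINDINSIGHT_HYPER_CONFIG, which Tuner.optimize()
# exports before spawning this process.
hyper_config = HyperConfig()
params = hyper_config.params

# Values suggested by the tuner for this run.
learning_rate = params.learning_rate
batch_size = params.batch_size

# Record lineage into the per-run summary directory; the new
# custom_lineage_data property carries the user-defined parameters so
# later tuning rounds can query them back from lineage.
summary_cb = SummaryCollector(
    summary_dir=hyper_config.summary_dir,
    custom_lineage_data=hyper_config.custom_lineage_data)

# ... build the model, then pass callbacks=[summary_cb] to model.train(...)
```

A run would then look something like `mindoptimizer --config ./config.yaml --iter 10`, with the config file's `command` field pointing at this script.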