1. add cli entry 2. add more validation for hyper config 3. add validator for config dict 4. add st for optimizer config validator 5. add custom lineage for hyper configtags/v1.1.0
| @@ -43,7 +43,7 @@ def get_optimize_targets(): | |||||
| def _get_optimize_targets(data_manager, search_condition=None): | def _get_optimize_targets(data_manager, search_condition=None): | ||||
| """Get optimize targets.""" | """Get optimize targets.""" | ||||
| table = get_lineage_table(data_manager, search_condition) | |||||
| table = get_lineage_table(data_manager=data_manager, search_condition=search_condition) | |||||
| target_summaries = [] | target_summaries = [] | ||||
| for target in table.target_names: | for target in table.target_names: | ||||
| @@ -15,6 +15,7 @@ | |||||
| """This file is used to parse lineage info.""" | """This file is used to parse lineage info.""" | ||||
| import os | import os | ||||
| from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageSummaryAnalyzeException, \ | from mindinsight.lineagemgr.common.exceptions.exceptions import LineageSummaryAnalyzeException, \ | ||||
| LineageEventNotExistException, LineageEventFieldNotExistException, LineageFileNotFoundError, \ | LineageEventNotExistException, LineageEventFieldNotExistException, LineageFileNotFoundError, \ | ||||
| MindInsightException | MindInsightException | ||||
| @@ -177,13 +178,49 @@ class LineageParser: | |||||
| class LineageOrganizer: | class LineageOrganizer: | ||||
| """Lineage organizer.""" | """Lineage organizer.""" | ||||
| def __init__(self, data_manager): | |||||
| def __init__(self, data_manager=None, summary_base_dir=None): | |||||
| self._data_manager = data_manager | self._data_manager = data_manager | ||||
| self._summary_base_dir = summary_base_dir | |||||
| self._check_params() | |||||
| self._super_lineage_objs = {} | self._super_lineage_objs = {} | ||||
| self._organize_from_cache() | self._organize_from_cache() | ||||
| self._organize_from_disk() | |||||
| def _check_params(self): | |||||
| """Check params.""" | |||||
| if self._data_manager is not None and self._summary_base_dir is not None: | |||||
| self._summary_base_dir = None | |||||
| def _organize_from_disk(self): | |||||
| """Organize lineage objs from disk.""" | |||||
| if self._summary_base_dir is None: | |||||
| return | |||||
| summary_watcher = SummaryWatcher() | |||||
| relative_dirs = summary_watcher.list_summary_directories( | |||||
| summary_base_dir=self._summary_base_dir | |||||
| ) | |||||
| no_lineage_count = 0 | |||||
| for item in relative_dirs: | |||||
| relative_dir = item.get('relative_path') | |||||
| update_time = item.get('update_time') | |||||
| abs_summary_dir = os.path.realpath(os.path.join(self._summary_base_dir, relative_dir)) | |||||
| try: | |||||
| lineage_parser = LineageParser(relative_dir, abs_summary_dir, update_time) | |||||
| super_lineage_obj = lineage_parser.super_lineage_obj | |||||
| if super_lineage_obj is not None: | |||||
| self._super_lineage_objs.update({abs_summary_dir: super_lineage_obj}) | |||||
| except LineageFileNotFoundError: | |||||
| no_lineage_count += 1 | |||||
| if no_lineage_count == len(relative_dirs): | |||||
| logger.info('There is no summary log file under summary_base_dir.') | |||||
| def _organize_from_cache(self): | def _organize_from_cache(self): | ||||
| """Organize lineage objs from cache.""" | """Organize lineage objs from cache.""" | ||||
| if self._data_manager is None: | |||||
| return | |||||
| brief_cache = self._data_manager.get_brief_cache() | brief_cache = self._data_manager.get_brief_cache() | ||||
| cache_items = brief_cache.cache_items | cache_items = brief_cache.cache_items | ||||
| for relative_dir, cache_train_job in cache_items.items(): | for relative_dir, cache_train_job in cache_items.items(): | ||||
| @@ -21,6 +21,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySumm | |||||
| from mindinsight.lineagemgr.common.log import logger as log | from mindinsight.lineagemgr.common.log import logger as log | ||||
| from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter | from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter | ||||
| from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition, validate_condition | from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition, validate_condition | ||||
| from mindinsight.lineagemgr.common.validator.validate_path import validate_and_normalize_path | |||||
| from mindinsight.lineagemgr.lineage_parser import LineageOrganizer | from mindinsight.lineagemgr.lineage_parser import LineageOrganizer | ||||
| from mindinsight.lineagemgr.querier.querier import Querier | from mindinsight.lineagemgr.querier.querier import Querier | ||||
| from mindinsight.optimizer.common.enums import ReasonCode | from mindinsight.optimizer.common.enums import ReasonCode | ||||
| @@ -33,7 +34,7 @@ USER_DEFINED_PREFIX = "[U]" | |||||
| USER_DEFINED_INFO_LIMIT = 100 | USER_DEFINED_INFO_LIMIT = 100 | ||||
| def filter_summary_lineage(data_manager, search_condition=None): | |||||
| def filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None): | |||||
| """ | """ | ||||
| Filter summary lineage from data_manager or parsing from summaries. | Filter summary lineage from data_manager or parsing from summaries. | ||||
| @@ -43,8 +44,18 @@ def filter_summary_lineage(data_manager, search_condition=None): | |||||
| Args: | Args: | ||||
| data_manager (DataManager): Data manager defined as | data_manager (DataManager): Data manager defined as | ||||
| mindinsight.datavisual.data_transform.data_manager.DataManager | mindinsight.datavisual.data_transform.data_manager.DataManager | ||||
| summary_base_dir (str): The summary base directory. It contains summary | |||||
| directories generated by training. | |||||
| search_condition (dict): The search condition. | search_condition (dict): The search condition. | ||||
| """ | """ | ||||
| if data_manager is None and summary_base_dir is None: | |||||
| raise LineageParamTypeError("One of data_manager or summary_base_dir needs to be specified.") | |||||
| if data_manager is None: | |||||
| summary_base_dir = validate_and_normalize_path(summary_base_dir, 'summary_base_dir') | |||||
| else: | |||||
| summary_base_dir = data_manager.summary_base_dir | |||||
| search_condition = {} if search_condition is None else search_condition | search_condition = {} if search_condition is None else search_condition | ||||
| try: | try: | ||||
| @@ -56,7 +67,7 @@ def filter_summary_lineage(data_manager, search_condition=None): | |||||
| raise LineageSearchConditionParamError(str(error.message)) | raise LineageSearchConditionParamError(str(error.message)) | ||||
| try: | try: | ||||
| lineage_objects = LineageOrganizer(data_manager).super_lineage_objs | |||||
| lineage_objects = LineageOrganizer(data_manager, summary_base_dir).super_lineage_objs | |||||
| result = Querier(lineage_objects).filter_summary_lineage(condition=search_condition) | result = Querier(lineage_objects).filter_summary_lineage(condition=search_condition) | ||||
| except LineageSummaryParseException: | except LineageSummaryParseException: | ||||
| result = {'object': [], 'count': 0} | result = {'object': [], 'count': 0} | ||||
| @@ -68,12 +79,13 @@ def filter_summary_lineage(data_manager, search_condition=None): | |||||
| return result | return result | ||||
| def get_flattened_lineage(data_manager, search_condition=None): | |||||
| def get_flattened_lineage(data_manager=None, summary_base_dir=None, search_condition=None): | |||||
| """ | """ | ||||
| Get lineage data in a table from data manager. | Get lineage data in a table from data manager. | ||||
| Args: | Args: | ||||
| data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading. | data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading. | ||||
| summary_base_dir (str): The base directory for train jobs. | |||||
| search_condition (dict): The search condition. | search_condition (dict): The search condition. | ||||
| Returns: | Returns: | ||||
| @@ -81,7 +93,7 @@ def get_flattened_lineage(data_manager, search_condition=None): | |||||
| """ | """ | ||||
| flatten_dict, user_count = {'train_id': []}, 0 | flatten_dict, user_count = {'train_id': []}, 0 | ||||
| lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition).get("object", []) | |||||
| lineages = filter_summary_lineage(data_manager, summary_base_dir, search_condition).get("object", []) | |||||
| for index, lineage in enumerate(lineages): | for index, lineage in enumerate(lineages): | ||||
| flatten_dict['train_id'].append(lineage.get("summary_dir")) | flatten_dict['train_id'].append(lineage.get("summary_dir")) | ||||
| for key, val in _flatten_lineage(lineage.get('model_lineage', {})): | for key, val in _flatten_lineage(lineage.get('model_lineage', {})): | ||||
| @@ -222,9 +234,9 @@ class LineageTable: | |||||
| return self._drop_columns_info | return self._drop_columns_info | ||||
| def get_lineage_table(data_manager, search_condition=None): | |||||
| def get_lineage_table(data_manager=None, summary_base_dir=None, search_condition=None): | |||||
| """Get lineage table from data_manager.""" | """Get lineage table from data_manager.""" | ||||
| lineage_table = get_flattened_lineage(data_manager, search_condition) | |||||
| lineage_table = get_flattened_lineage(data_manager, summary_base_dir, search_condition) | |||||
| lineage_table = LineageTable(pd.DataFrame(lineage_table)) | lineage_table = LineageTable(pd.DataFrame(lineage_table)) | ||||
| return lineage_table | return lineage_table | ||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Command module.""" | |||||
| import argparse | |||||
| import os | |||||
| import sys | |||||
| import mindinsight | |||||
| from mindinsight.optimizer.tuner import Tuner | |||||
| class ConfigAction(argparse.Action): | |||||
| """Summary base dir action class definition.""" | |||||
| def __call__(self, parser_in, namespace, values, option_string=None): | |||||
| """ | |||||
| Inherited __call__ method from argparse.Action. | |||||
| Args: | |||||
| parser_in (ArgumentParser): Passed-in argument parser. | |||||
| namespace (Namespace): Namespace object to hold arguments. | |||||
| values (object): Argument values with type depending on argument definition. | |||||
| option_string (str): Option string for specific argument name. | |||||
| """ | |||||
| config_path = os.path.realpath(values) | |||||
| if not os.path.exists(config_path): | |||||
| parser_in.error(f'{option_string} {config_path} not exists.') | |||||
| setattr(namespace, self.dest, config_path) | |||||
| class IterAction(argparse.Action): | |||||
| """Summary base dir action class definition.""" | |||||
| def __call__(self, parser_in, namespace, values, option_string=None): | |||||
| """ | |||||
| Inherited __call__ method from argparse.Action. | |||||
| Args: | |||||
| parser_in (ArgumentParser): Passed-in argument parser. | |||||
| namespace (Namespace): Namespace object to hold arguments. | |||||
| values (object): Argument values with type depending on argument definition. | |||||
| option_string (str): Option string for specific argument name. | |||||
| """ | |||||
| iter_times = values | |||||
| if iter_times <= 0: | |||||
| parser_in.error(f'{option_string} {iter_times} should be a positive integer.') | |||||
| setattr(namespace, self.dest, iter_times) | |||||
| parser = argparse.ArgumentParser( | |||||
| prog='mindoptimizer', | |||||
| description='MindOptimizer CLI entry point (version: {})'.format(mindinsight.__version__)) | |||||
| parser.add_argument( | |||||
| '--version', | |||||
| action='version', | |||||
| version='%(prog)s ({})'.format(mindinsight.__version__)) | |||||
| parser.add_argument( | |||||
| '--config', | |||||
| type=str, | |||||
| action=ConfigAction, | |||||
| required=True, | |||||
| default=os.path.join(os.getcwd(), 'output'), | |||||
| help="Specify path for config file." | |||||
| ) | |||||
| parser.add_argument( | |||||
| '--iter', | |||||
| type=int, | |||||
| action=IterAction, | |||||
| default=1, | |||||
| help="Optional, specify run times for the command in config file." | |||||
| ) | |||||
| def cli_entry(): | |||||
| """Cli entry.""" | |||||
| argv = sys.argv[1:] | |||||
| if not argv: | |||||
| argv = ['-h'] | |||||
| args = parser.parse_args(argv) | |||||
| else: | |||||
| args = parser.parse_args() | |||||
| tuner = Tuner(args.config) | |||||
| tuner.optimize(max_expr_times=args.iter) | |||||
| @@ -14,3 +14,4 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Common constants for optimizer.""" | """Common constants for optimizer.""" | ||||
| HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG" | HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG" | ||||
| HYPER_CONFIG_LEN_LIMIT = 100000 | |||||
| @@ -48,12 +48,17 @@ class TuneMethod(BaseEnum): | |||||
| GP = 'gp' | GP = 'gp' | ||||
| class GPSupportArgs(BaseEnum): | |||||
| METHOD = 'method' | |||||
| class HyperParamKey(BaseEnum): | class HyperParamKey(BaseEnum): | ||||
| """Config keys for hyper parameters.""" | """Config keys for hyper parameters.""" | ||||
| BOUND = 'bounds' | BOUND = 'bounds' | ||||
| CHOICE = 'choice' | CHOICE = 'choice' | ||||
| DECIMAL = 'decimal' | DECIMAL = 'decimal' | ||||
| TYPE = 'type' | TYPE = 'type' | ||||
| SOURCE = 'source' | |||||
| class HyperParamType(BaseEnum): | class HyperParamType(BaseEnum): | ||||
| @@ -73,3 +78,25 @@ class TargetGoal(BaseEnum): | |||||
| """Goal for target.""" | """Goal for target.""" | ||||
| MAXIMUM = 'maximize' | MAXIMUM = 'maximize' | ||||
| MINIMUM = 'minimize' | MINIMUM = 'minimize' | ||||
| class HyperParamSource(BaseEnum): | |||||
| SYSTEM_DEFINED = 'system_defined' | |||||
| USER_DEFINED = 'user_defined' | |||||
| class TargetGroup(BaseEnum): | |||||
| SYSTEM_DEFINED = 'system_defined' | |||||
| METRIC = 'metric' | |||||
| class TunableSystemDefinedParams(BaseEnum): | |||||
| """Tunable metadata keys of lineage collection.""" | |||||
| BATCH_SIZE = 'batch_size' | |||||
| EPOCH = 'epoch' | |||||
| LEARNING_RATE = 'learning_rate' | |||||
| class SystemDefinedTargets(BaseEnum): | |||||
| """System defined targets""" | |||||
| LOSS = 'loss' | |||||
| @@ -48,3 +48,19 @@ class OptimizerTerminateError(MindInsightException): | |||||
| super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE, | super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE, | ||||
| error_msg, | error_msg, | ||||
| http_code=400) | http_code=400) | ||||
| class ConfigParamError(MindInsightException): | |||||
| """Hyper config error.""" | |||||
| def __init__(self, error_msg="Invalid parameter."): | |||||
| super(ConfigParamError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE, | |||||
| error_msg, | |||||
| http_code=400) | |||||
| class HyperConfigEnvError(MindInsightException): | |||||
| """Hyper config error.""" | |||||
| def __init__(self, error_msg="Hyper config is not correct."): | |||||
| super(HyperConfigEnvError, self).__init__(OptimizerErrors.HYPER_CONFIG_ENV_ERROR, | |||||
| error_msg, | |||||
| http_code=400) | |||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Common validator modules for optimizer.""" | |||||
| @@ -0,0 +1,202 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Validator for optimizer config.""" | |||||
| import math | |||||
| from marshmallow import Schema, fields, ValidationError, validates, validate, validates_schema | |||||
| from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \ | |||||
| HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \ | |||||
| HyperParamKey, SystemDefinedTargets | |||||
| _BOUND_LEN = 2 | |||||
| _NUMBER_ERR_MSG = "Value(s) should be integer or float." | |||||
| _TYPE_ERR_MSG = "Value type should be %r." | |||||
| _VALUE_ERR_MSG = "Value should be in %s. Current value is %s." | |||||
| def _generate_schema_err_msg(err_msg, *args): | |||||
| """Organize error messages.""" | |||||
| if args: | |||||
| err_msg = err_msg % args | |||||
| return {"invalid": err_msg} | |||||
| def include_integer(low, high): | |||||
| """Check if the range [low, high) includes integer.""" | |||||
| def _in_range(num, low, high): | |||||
| """check if num in [low, high)""" | |||||
| return low <= num < high | |||||
| if _in_range(math.ceil(low), low, high) or _in_range(math.floor(high), low, high): | |||||
| return True | |||||
| return False | |||||
| class TunerSchema(Schema): | |||||
| """Schema for tuner.""" | |||||
| dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict") | |||||
| name = fields.Str(required=True, | |||||
| validate=validate.OneOf(TuneMethod.list_members()), | |||||
| error_messages=_generate_schema_err_msg("Value should be in %s." % TuneMethod.list_members())) | |||||
| args = fields.Dict(error_messages=dict_err_msg) | |||||
| @validates("args") | |||||
| def check_args(self, data): | |||||
| """Check args for tuner.""" | |||||
| data_keys = list(data.keys()) | |||||
| support_args = GPSupportArgs.list_members() | |||||
| if not set(data_keys).issubset(set(support_args)): | |||||
| raise ValidationError("Only support setting %s for tuner. " | |||||
| "Current key(s): %s." % (support_args, data_keys)) | |||||
| method = data.get(GPSupportArgs.METHOD.value) | |||||
| if not isinstance(method, str): | |||||
| raise ValidationError("The 'method' type should be str.") | |||||
| if method not in AcquisitionFunctionEnum.list_members(): | |||||
| raise ValidationError("Supported acquisition function must be one of %s. Current value is %r." % | |||||
| (AcquisitionFunctionEnum.list_members(), method)) | |||||
| class ParameterSchema(Schema): | |||||
| """Schema for parameter.""" | |||||
| number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG) | |||||
| list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list") | |||||
| str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str") | |||||
| bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg) | |||||
| choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg) | |||||
| type = fields.Str(error_messages=list_err_msg) | |||||
| source = fields.Str(error_messages=str_err_msg) | |||||
| @validates("bounds") | |||||
| def check_bounds(self, bounds): | |||||
| """Check if bounds are valid.""" | |||||
| if len(bounds) != _BOUND_LEN: | |||||
| raise ValidationError("Length of bounds should be %s." % _BOUND_LEN) | |||||
| if bounds[1] <= bounds[0]: | |||||
| raise ValidationError("The upper bound must be greater than lower bound. " | |||||
| "The range is [lower_bound, upper_bound).") | |||||
| @validates("type") | |||||
| def check_type(self, type_in): | |||||
| """Check if type is valid.""" | |||||
| if type_in not in HyperParamType.list_members(): | |||||
| raise ValidationError("The type should be in %s." % HyperParamType.list_members()) | |||||
| @validates("source") | |||||
| def check_source(self, source): | |||||
| """Check if source is valid.""" | |||||
| if source not in HyperParamSource.list_members(): | |||||
| raise ValidationError(_VALUE_ERR_MSG % (HyperParamSource.list_members(), source)) | |||||
| @validates_schema | |||||
| def check_combination(self, data, **kwargs): | |||||
| """check the combination of parameters.""" | |||||
| bound_key = HyperParamKey.BOUND.value | |||||
| choice_key = HyperParamKey.CHOICE.value | |||||
| type_key = HyperParamKey.TYPE.value | |||||
| # check bound and type | |||||
| bounds = data.get(bound_key) | |||||
| param_type = data.get(type_key) | |||||
| if bounds is not None: | |||||
| if param_type is None: | |||||
| raise ValidationError("If %r is specified, the %r should be specified also." % | |||||
| (HyperParamKey.BOUND.value, HyperParamKey.TYPE.value)) | |||||
| if param_type == HyperParamType.INT.value and not include_integer(bounds[0], bounds[1]): | |||||
| raise ValidationError("No integer in 'bounds', please modify it.") | |||||
| # check bound and choice | |||||
| if (bound_key in data and choice_key in data) or (bound_key not in data and choice_key not in data): | |||||
| raise ValidationError("Only one of [%r, %r] should be specified." % | |||||
| (bound_key, choice_key)) | |||||
| class TargetSchema(Schema): | |||||
| """Schema for target.""" | |||||
| str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str") | |||||
| group = fields.Str(error_messages=str_err_msg) | |||||
| name = fields.Str(required=True, error_messages=str_err_msg) | |||||
| goal = fields.Str(error_messages=str_err_msg) | |||||
| @validates("group") | |||||
| def check_group(self, group): | |||||
| """Check if bounds are valid.""" | |||||
| if group not in TargetGroup.list_members(): | |||||
| raise ValidationError(_VALUE_ERR_MSG % (TargetGroup.list_members(), group)) | |||||
| @validates("goal") | |||||
| def check_goal(self, goal): | |||||
| """Check if source is valid.""" | |||||
| if goal not in TargetGoal.list_members(): | |||||
| raise ValidationError(_VALUE_ERR_MSG % (TargetGoal.list_members(), goal)) | |||||
| @validates_schema | |||||
| def check_combination(self, data, **kwargs): | |||||
| """check the combination of parameters.""" | |||||
| if TargetKey.GROUP.value not in data: | |||||
| # if name is in system_defined keys, group will be 'system_defined', else will be 'user_defined'. | |||||
| return | |||||
| name = data.get(TargetKey.NAME.value) | |||||
| group = data.get(TargetKey.GROUP.value) | |||||
| if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members(): | |||||
| raise ValidationError({ | |||||
| TargetKey.GROUP.value: "This target is not system defined. Current group is: %s." % group}) | |||||
| class OptimizerConfig(Schema): | |||||
| """Define the search model condition parameter schema.""" | |||||
| dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict") | |||||
| str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str") | |||||
| summary_base_dir = fields.Str(required=True, error_messages=str_err_msg) | |||||
| command = fields.Str(required=True, error_messages=str_err_msg) | |||||
| tuner = fields.Dict(required=True, error_messages=dict_err_msg) | |||||
| target = fields.Dict(required=True, error_messages=dict_err_msg) | |||||
| parameters = fields.Dict(required=True, error_messages=dict_err_msg) | |||||
| @validates("tuner") | |||||
| def check_tuner(self, data): | |||||
| """Check tuner.""" | |||||
| err = TunerSchema().validate(data) | |||||
| if err: | |||||
| raise ValidationError(err) | |||||
| @validates("parameters") | |||||
| def check_parameters(self, parameters): | |||||
| """Check parameters.""" | |||||
| for name, value in parameters.items(): | |||||
| err = ParameterSchema().validate(value) | |||||
| if err: | |||||
| raise ValidationError({name: err}) | |||||
| if HyperParamKey.SOURCE.value not in value: | |||||
| # if params is in system_defined keys, group will be 'system_defined', else will be 'user_defined'. | |||||
| continue | |||||
| source = value.get(HyperParamKey.SOURCE.value) | |||||
| if source == HyperParamSource.SYSTEM_DEFINED.value and \ | |||||
| name not in TunableSystemDefinedParams.list_members(): | |||||
| raise ValidationError({ | |||||
| name: {"source": "This param is not system defined. Current source is: %s." % source}}) | |||||
| @validates("target") | |||||
| def check_target(self, target): | |||||
| """Check target.""" | |||||
| err = TargetSchema().validate(target) | |||||
| if err: | |||||
| raise ValidationError(err) | |||||
| @@ -15,30 +15,73 @@ | |||||
| """Hyper config.""" | """Hyper config.""" | ||||
| import json | import json | ||||
| import os | import os | ||||
| from attrdict import AttrDict | |||||
| from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME | |||||
| from mindinsight.optimizer.common.exceptions import HyperConfigError | |||||
| from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME, HYPER_CONFIG_LEN_LIMIT | |||||
| from mindinsight.optimizer.common.exceptions import HyperConfigEnvError, HyperConfigError | |||||
| _HYPER_CONFIG_LEN_LIMIT = 100000 | |||||
| class AttributeDict(dict): | |||||
| """A dict can be accessed by attribute.""" | |||||
| def __init__(self, d=None): | |||||
| super().__init__() | |||||
| if d is not None: | |||||
| for k, v in d.items(): | |||||
| self[k] = v | |||||
| class HyperConfig: | |||||
| """ | |||||
| Hyper config. | |||||
| def __key(self, key): | |||||
| """Get key.""" | |||||
| return "" if key is None else key | |||||
| def __setattr__(self, key, value): | |||||
| """Set attribute for object.""" | |||||
| self[self.__key(key)] = value | |||||
| def __getattr__(self, key): | |||||
| """ | |||||
| Get attribute value according by attribute name. | |||||
| Args: | |||||
| key (str): attribute name. | |||||
| Returns: | |||||
| Any, attribute value. | |||||
| Init hyper config: | |||||
| >>> hyper_config = HyperConfig() | |||||
| Raises: | |||||
| AttributeError: If the key does not exists, will raise Exception. | |||||
| Get suggest params: | |||||
| >>> param_obj = hyper_config.params | |||||
| >>> learning_rate = params.learning_rate | |||||
| """ | |||||
| value = self.get(self.__key(key)) | |||||
| if value is None: | |||||
| raise AttributeError("The attribute %r is not exist." % key) | |||||
| return value | |||||
| Get summary dir: | |||||
| >>> summary_dir = hyper_config.summary_dir | |||||
| def __getitem__(self, key): | |||||
| """Get attribute value according by attribute name.""" | |||||
| value = super().get(self.__key(key)) | |||||
| if value is None: | |||||
| raise AttributeError("The attribute %r is not exist." % key) | |||||
| return value | |||||
| Record by SummaryCollector: | |||||
| >>> summary_cb = SummaryCollector(summary_dir) | |||||
| def __setitem__(self, key, value): | |||||
| """Set attribute for object.""" | |||||
| return super().__setitem__(self.__key(key), value) | |||||
| class HyperConfig: | |||||
| """ | |||||
| Hyper config. | |||||
| 1. Init HyperConfig. | |||||
| 2. Get suggested params and summary_dir. | |||||
| 3. Record by SummaryCollector with summary_dir. | |||||
| Examples: | |||||
| >>> hyper_config = HyperConfig() | |||||
| >>> params = hyper_config.params | |||||
| >>> learning_rate = params.learning_rate | |||||
| >>> batch_size = params.batch_size | |||||
| >>> summary_dir = hyper_config.summary_dir | |||||
| >>> summary_cb = SummaryCollector(summary_dir) | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| self._init_validate_hyper_config() | self._init_validate_hyper_config() | ||||
| @@ -47,10 +90,10 @@ class HyperConfig: | |||||
| """Init and validate hyper config.""" | """Init and validate hyper config.""" | ||||
| hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME) | hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME) | ||||
| if hyper_config is None: | if hyper_config is None: | ||||
| raise HyperConfigError("Hyper config is not in system environment.") | |||||
| if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT: | |||||
| raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of " | |||||
| "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config))) | |||||
| raise HyperConfigEnvError("Hyper config is not in system environment.") | |||||
| if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT: | |||||
| raise HyperConfigEnvError("Hyper config is too long. The length limit is %s, the length of " | |||||
| "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config))) | |||||
| try: | try: | ||||
| hyper_config = json.loads(hyper_config) | hyper_config = json.loads(hyper_config) | ||||
| @@ -60,8 +103,7 @@ class HyperConfig: | |||||
| raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc)) | raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc)) | ||||
| self._validate_hyper_config(hyper_config) | self._validate_hyper_config(hyper_config) | ||||
| self._summary_dir = hyper_config.get('summary_dir') | |||||
| self._param_obj = AttrDict(hyper_config.get('params')) | |||||
| self._hyper_config = hyper_config | |||||
| def _validate_hyper_config(self, hyper_config): | def _validate_hyper_config(self, hyper_config): | ||||
| """Validate hyper config.""" | """Validate hyper config.""" | ||||
| @@ -86,9 +128,13 @@ class HyperConfig: | |||||
| @property | @property | ||||
| def params(self): | def params(self): | ||||
| """Get params.""" | """Get params.""" | ||||
| return self._param_obj | |||||
| return AttributeDict(self._hyper_config.get('params')) | |||||
| @property | @property | ||||
| def summary_dir(self): | def summary_dir(self): | ||||
| """Get train summary dir path.""" | """Get train summary dir path.""" | ||||
| return self._summary_dir | |||||
| return self._hyper_config.get('summary_dir') | |||||
| @property | |||||
| def custom_lineage_data(self): | |||||
| return self._hyper_config.get('custom_lineage_data') | |||||
| @@ -22,16 +22,16 @@ import yaml | |||||
| from marshmallow import ValidationError | from marshmallow import ValidationError | ||||
| from mindinsight.datavisual.data_transform.data_manager import DataManager | |||||
| from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater | |||||
| from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path | from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path | ||||
| from mindinsight.lineagemgr.model import get_lineage_table, LineageTable, METRIC_PREFIX | |||||
| from mindinsight.lineagemgr.model import get_lineage_table, LineageTable | |||||
| from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME | from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME | ||||
| from mindinsight.optimizer.common.enums import TuneMethod, TargetKey, TargetGoal | |||||
| from mindinsight.optimizer.common.exceptions import OptimizerTerminateError | |||||
| from mindinsight.optimizer.common.enums import TuneMethod | |||||
| from mindinsight.optimizer.common.exceptions import OptimizerTerminateError, ConfigParamError | |||||
| from mindinsight.optimizer.common.log import logger | from mindinsight.optimizer.common.log import logger | ||||
| from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig | |||||
| from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner | from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner | ||||
| from mindinsight.optimizer.utils.param_handler import organize_params_target | from mindinsight.optimizer.utils.param_handler import organize_params_target | ||||
| from mindinsight.optimizer.utils.utils import get_nested_message | |||||
| from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError | from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError | ||||
| _OK = 0 | _OK = 0 | ||||
| @@ -51,7 +51,6 @@ class Tuner: | |||||
| def __init__(self, config_path: str): | def __init__(self, config_path: str): | ||||
| self._config_info = self._validate_config(config_path) | self._config_info = self._validate_config(config_path) | ||||
| self._summary_base_dir = self._config_info.get('summary_base_dir') | self._summary_base_dir = self._config_info.get('summary_base_dir') | ||||
| self._data_manager = self._init_data_manager() | |||||
| self._dir_prefix = 'train' | self._dir_prefix = 'train' | ||||
| def _validate_config(self, config_path): | def _validate_config(self, config_path): | ||||
| @@ -65,11 +64,17 @@ class Tuner: | |||||
| except Exception as exc: | except Exception as exc: | ||||
| raise UnknownError("Detail: %s." % str(exc)) | raise UnknownError("Detail: %s." % str(exc)) | ||||
| # need to add validation for config_info: command, summary_base_dir, target and params. | |||||
| self._validate_config_schema(config_info) | |||||
| config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir')) | config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir')) | ||||
| self._make_summary_base_dir(config_info['summary_base_dir']) | self._make_summary_base_dir(config_info['summary_base_dir']) | ||||
| return config_info | return config_info | ||||
| def _validate_config_schema(self, config_info): | |||||
| error = OptimizerConfig().validate(config_info) | |||||
| if error: | |||||
| err_msg = get_nested_message(error) | |||||
| raise ConfigParamError(err_msg) | |||||
| def _make_summary_base_dir(self, summary_base_dir): | def _make_summary_base_dir(self, summary_base_dir): | ||||
| """Check and make summary_base_dir.""" | """Check and make summary_base_dir.""" | ||||
| if not os.path.exists(summary_base_dir): | if not os.path.exists(summary_base_dir): | ||||
| @@ -82,13 +87,6 @@ class Tuner: | |||||
| except OSError as exc: | except OSError as exc: | ||||
| raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc)) | raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc)) | ||||
| def _init_data_manager(self): | |||||
| """Initialize data_manager.""" | |||||
| data_manager = DataManager(summary_base_dir=self._summary_base_dir) | |||||
| data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater()) | |||||
| return data_manager | |||||
| def _normalize_path(self, param_name, path): | def _normalize_path(self, param_name, path): | ||||
| """Normalize config path.""" | """Normalize config path.""" | ||||
| path = os.path.realpath(path) | path = os.path.realpath(path) | ||||
| @@ -104,10 +102,8 @@ class Tuner: | |||||
| def _update_from_lineage(self): | def _update_from_lineage(self): | ||||
| """Update lineage from lineagemgr.""" | """Update lineage from lineagemgr.""" | ||||
| self._data_manager.start_load_data(reload_interval=0).join() | |||||
| try: | try: | ||||
| lineage_table = get_lineage_table(self._data_manager) | |||||
| lineage_table = get_lineage_table(summary_base_dir=self._summary_base_dir) | |||||
| except MindInsightException as err: | except MindInsightException as err: | ||||
| logger.info("Can not query lineage. Detail: %s", str(err)) | logger.info("Can not query lineage. Detail: %s", str(err)) | ||||
| lineage_table = None | lineage_table = None | ||||
| @@ -122,12 +118,14 @@ class Tuner: | |||||
| tuner = self._config_info.get('tuner') | tuner = self._config_info.get('tuner') | ||||
| for _ in range(max_expr_times): | for _ in range(max_expr_times): | ||||
| self._update_from_lineage() | self._update_from_lineage() | ||||
| suggestion = self._suggest(self._lineage_table, params_info, target_info, method=tuner.get("name")) | |||||
| suggestion, user_defined_info = self._suggest(self._lineage_table, params_info, target_info, tuner) | |||||
| hyper_config = { | hyper_config = { | ||||
| 'params': suggestion, | 'params': suggestion, | ||||
| 'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}') | |||||
| 'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}'), | |||||
| 'custom_lineage_data': user_defined_info | |||||
| } | } | ||||
| logger.info("Suggest values are: %s.", suggestion) | |||||
| os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config) | os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config) | ||||
| s = subprocess.Popen(shlex.split(command)) | s = subprocess.Popen(shlex.split(command)) | ||||
| s.wait() | s.wait() | ||||
| @@ -135,29 +133,26 @@ class Tuner: | |||||
| logger.error("An error occurred during execution, the auto tuning will be terminated.") | logger.error("An error occurred during execution, the auto tuning will be terminated.") | ||||
| raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.") | raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.") | ||||
| def _get_tuner(self, tune_method=TuneMethod.GP.value): | |||||
| def _get_tuner(self, tuner): | |||||
| """Get tuner.""" | """Get tuner.""" | ||||
| if tune_method.lower() not in TuneMethod.list_members(): | |||||
| if tuner is None: | |||||
| return GPBaseTuner() | |||||
| tuner_name = tuner.get("name").lower() | |||||
| if tuner_name not in TuneMethod.list_members(): | |||||
| raise ParamValueError("'tune_method' should in %s." % TuneMethod.list_members()) | raise ParamValueError("'tune_method' should in %s." % TuneMethod.list_members()) | ||||
| args = tuner.get("args") | |||||
| if args is not None and args.get("method") is not None: | |||||
| return GPBaseTuner(args.get("method")) | |||||
| # Only support gaussian process regressor currently. | # Only support gaussian process regressor currently. | ||||
| return GPBaseTuner() | return GPBaseTuner() | ||||
| def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, method): | |||||
| def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, tuner): | |||||
| """Get suggestions for targets.""" | """Get suggestions for targets.""" | ||||
| tuner = self._get_tuner(method) | |||||
| target_name = target_info[TargetKey.NAME.value] | |||||
| if TargetKey.GROUP.value in target_info and target_info[TargetKey.GROUP.value] == 'metric': | |||||
| target_name = METRIC_PREFIX + target_name | |||||
| param_matrix, target_matrix = organize_params_target(lineage_table, params_info, target_name) | |||||
| if not param_matrix.empty: | |||||
| suggestion = tuner.suggest([], [], params_info) | |||||
| else: | |||||
| target_column = target_matrix[target_name].reshape((-1, 1)) | |||||
| if target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value: | |||||
| target_column = -target_column | |||||
| suggestion = tuner.suggest(param_matrix, target_column, params_info) | |||||
| return suggestion | |||||
| tuner = self._get_tuner(tuner) | |||||
| param_matrix, target_column = organize_params_target(lineage_table, params_info, target_info) | |||||
| suggestion, user_defined_info = tuner.suggest(param_matrix, target_column, params_info) | |||||
| return suggestion, user_defined_info | |||||
| @@ -22,10 +22,11 @@ from sklearn.gaussian_process import GaussianProcessRegressor | |||||
| from sklearn.gaussian_process.kernels import Matern | from sklearn.gaussian_process.kernels import Matern | ||||
| from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey | from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey | ||||
| from mindinsight.optimizer.common.log import logger | |||||
| from mindinsight.optimizer.tuners.base_tuner import BaseTuner | |||||
| from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type | from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type | ||||
| from mindinsight.optimizer.utils.transformer import Transformer | from mindinsight.optimizer.utils.transformer import Transformer | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from mindinsight.optimizer.tuners.base_tuner import BaseTuner | |||||
| class AcquisitionFunction: | class AcquisitionFunction: | ||||
| @@ -141,8 +142,10 @@ class GPBaseTuner(BaseTuner): | |||||
| x_seeds = generate_arrays(params_info, n_iter) | x_seeds = generate_arrays(params_info, n_iter) | ||||
| for x_try in x_seeds: | for x_try in x_seeds: | ||||
| res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max), | |||||
| x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B") | |||||
| with warnings.catch_warnings(): | |||||
| warnings.simplefilter("ignore") | |||||
| res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max), | |||||
| x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B") | |||||
| if not res.success: | if not res.success: | ||||
| continue | continue | ||||
| @@ -164,6 +167,8 @@ class GPBaseTuner(BaseTuner): | |||||
| min_lineage_rows = 2 | min_lineage_rows = 2 | ||||
| if not np.array(params).any() or params.shape[0] < min_lineage_rows: | if not np.array(params).any() or params.shape[0] < min_lineage_rows: | ||||
| logger.info("Without valid histories or the rows of lineages < %s, " | |||||
| "parameters will be recommended randomly.", min_lineage_rows) | |||||
| suggestion = generate_arrays(params_info) | suggestion = generate_arrays(params_info) | ||||
| else: | else: | ||||
| self._gp.fit(params, target) | self._gp.fit(params, target) | ||||
| @@ -174,5 +179,5 @@ class GPBaseTuner(BaseTuner): | |||||
| params_info=params_info | params_info=params_info | ||||
| ) | ) | ||||
| suggestion = Transformer.transform_list_to_dict(params_info, suggestion) | |||||
| return suggestion | |||||
| suggestion, user_defined_info = Transformer.transform_list_to_dict(params_info, suggestion) | |||||
| return suggestion, user_defined_info | |||||
| @@ -14,8 +14,9 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Utils for params.""" | """Utils for params.""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.lineagemgr.model import LineageTable | |||||
| from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType | |||||
| from mindinsight.lineagemgr.model import LineageTable, USER_DEFINED_PREFIX, METRIC_PREFIX | |||||
| from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType, HyperParamSource, TargetKey, \ | |||||
| TargetGoal, TunableSystemDefinedParams, TargetGroup, SystemDefinedTargets | |||||
| from mindinsight.optimizer.common.log import logger | from mindinsight.optimizer.common.log import logger | ||||
| @@ -54,7 +55,6 @@ def match_value_type(array, params_info: dict): | |||||
| array_new = [] | array_new = [] | ||||
| index = 0 | index = 0 | ||||
| for _, param_info in params_info.items(): | for _, param_info in params_info.items(): | ||||
| param_type = param_info[HyperParamKey.TYPE.value] | |||||
| value = array[index] | value = array[index] | ||||
| if HyperParamKey.BOUND.value in param_info: | if HyperParamKey.BOUND.value in param_info: | ||||
| bound = param_info[HyperParamKey.BOUND.value] | bound = param_info[HyperParamKey.BOUND.value] | ||||
| @@ -64,7 +64,7 @@ def match_value_type(array, params_info: dict): | |||||
| choices = param_info[HyperParamKey.CHOICE.value] | choices = param_info[HyperParamKey.CHOICE.value] | ||||
| nearest_index = int(np.argmin(np.fabs(np.array(choices) - value))) | nearest_index = int(np.argmin(np.fabs(np.array(choices) - value))) | ||||
| value = choices[nearest_index] | value = choices[nearest_index] | ||||
| if param_type == HyperParamType.INT.value: | |||||
| if param_info.get(HyperParamKey.TYPE.value) == HyperParamType.INT.value: | |||||
| value = int(value) | value = int(value) | ||||
| if HyperParamKey.DECIMAL.value in param_info: | if HyperParamKey.DECIMAL.value in param_info: | ||||
| value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value]) | value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value]) | ||||
| @@ -73,20 +73,51 @@ def match_value_type(array, params_info: dict): | |||||
| return array_new | return array_new | ||||
| def organize_params_target(lineage_table: LineageTable, params_info: dict, target_name): | |||||
| def organize_params_target(lineage_table: LineageTable, params_info: dict, target_info): | |||||
| """Organize params and target.""" | """Organize params and target.""" | ||||
| empty_result = np.array([]) | empty_result = np.array([]) | ||||
| if lineage_table is None: | if lineage_table is None: | ||||
| return empty_result, empty_result | return empty_result, empty_result | ||||
| param_keys = list(params_info.keys()) | |||||
| param_keys = [] | |||||
| for param_key, param_info in params_info.items(): | |||||
| # It will be a user_defined param: | |||||
| # 1. if 'source' is specified as 'user_defined' | |||||
| # 2. if 'source' is not specified and the param is not a system_defined key | |||||
| source = param_info.get(HyperParamKey.SOURCE.value) | |||||
| prefix = _get_prefix(param_key, source, HyperParamSource.USER_DEFINED.value, | |||||
| USER_DEFINED_PREFIX, TunableSystemDefinedParams.list_members()) | |||||
| param_key = f'{prefix}{param_key}' | |||||
| if prefix == USER_DEFINED_PREFIX: | |||||
| param_info[HyperParamKey.SOURCE.value] = HyperParamSource.USER_DEFINED.value | |||||
| else: | |||||
| param_info[HyperParamKey.SOURCE.value] = HyperParamSource.SYSTEM_DEFINED.value | |||||
| param_keys.append(param_key) | |||||
| target_name = target_info[TargetKey.NAME.value] | |||||
| group = target_info.get(TargetKey.GROUP.value) | |||||
| prefix = _get_prefix(target_name, group, TargetGroup.METRIC.value, | |||||
| METRIC_PREFIX, SystemDefinedTargets.list_members()) | |||||
| target_name = prefix + target_name | |||||
| lineage_df = lineage_table.dataframe_data | lineage_df = lineage_table.dataframe_data | ||||
| try: | try: | ||||
| lineage_df = lineage_df[param_keys + [target_name]] | lineage_df = lineage_df[param_keys + [target_name]] | ||||
| lineage_df = lineage_df.dropna(axis=0, how='any') | lineage_df = lineage_df.dropna(axis=0, how='any') | ||||
| return lineage_df[param_keys], lineage_df[target_name] | |||||
| target_column = np.array(lineage_df[target_name]) | |||||
| if TargetKey.GOAL.value in target_info and \ | |||||
| target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value: | |||||
| target_column = -target_column | |||||
| return np.array(lineage_df[param_keys]), target_column | |||||
| except KeyError as exc: | except KeyError as exc: | ||||
| logger.warning("Some keys not exist in specified params or target. It will suggest params randomly." | logger.warning("Some keys not exist in specified params or target. It will suggest params randomly." | ||||
| "Detail: %s.", str(exc)) | "Detail: %s.", str(exc)) | ||||
| return empty_result, empty_result | return empty_result, empty_result | ||||
| def _get_prefix(name, field, other_defined_field, other_defined_prefix, system_defined_fields): | |||||
| if (field == other_defined_field) or (field is None and name not in system_defined_fields): | |||||
| return other_defined_prefix | |||||
| return '' | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Transformer.""" | """Transformer.""" | ||||
| from mindinsight.optimizer.utils.param_handler import match_value_type | from mindinsight.optimizer.utils.param_handler import match_value_type | ||||
| from mindinsight.optimizer.common.enums import HyperParamSource, HyperParamKey | |||||
| class Transformer: | class Transformer: | ||||
| @@ -23,7 +24,12 @@ class Transformer: | |||||
| """Transform from tuner.""" | """Transform from tuner.""" | ||||
| suggest_list = match_value_type(suggest_list, params_info) | suggest_list = match_value_type(suggest_list, params_info) | ||||
| param_dict = {} | param_dict = {} | ||||
| user_defined_info = {} | |||||
| for index, param_name in enumerate(params_info): | for index, param_name in enumerate(params_info): | ||||
| param_dict.update({param_name: suggest_list[index]}) | |||||
| param_item = {param_name: suggest_list[index]} | |||||
| param_dict.update(param_item) | |||||
| source = params_info.get(param_name).get(HyperParamKey.SOURCE.value) | |||||
| if source is not None and source == HyperParamSource.USER_DEFINED.value: | |||||
| user_defined_info.update(param_item) | |||||
| return param_dict | |||||
| return param_dict, user_defined_info | |||||
| @@ -68,3 +68,18 @@ def is_simple_numpy_number(dtype): | |||||
| return True | return True | ||||
| return False | return False | ||||
| def get_nested_message(info: dict, out_err_msg=""): | |||||
| """Get error message from the error dict generated by schema validation.""" | |||||
| if not isinstance(info, dict): | |||||
| if isinstance(info, list): | |||||
| info = info[0] | |||||
| return f'Error in {out_err_msg}: {info}' | |||||
| for key in info: | |||||
| if isinstance(key, str) and key != '_schema': | |||||
| if out_err_msg: | |||||
| out_err_msg = f'{out_err_msg}.{key}' | |||||
| else: | |||||
| out_err_msg = key | |||||
| return get_nested_message(info[key], out_err_msg) | |||||
| @@ -98,3 +98,5 @@ class OptimizerErrors(Enum): | |||||
| CORRELATION_NAN = 2 | CORRELATION_NAN = 2 | ||||
| HYPER_CONFIG_ERROR = 3 | HYPER_CONFIG_ERROR = 3 | ||||
| OPTIMIZER_TERMINATE = 4 | OPTIMIZER_TERMINATE = 4 | ||||
| CONFIG_PARAM_ERROR = 5 | |||||
| HYPER_CONFIG_ENV_ERROR = 6 | |||||
| @@ -10,9 +10,9 @@ marshmallow>=2.19.2 | |||||
| numpy>=1.17.0 | numpy>=1.17.0 | ||||
| protobuf>=3.8.0 | protobuf>=3.8.0 | ||||
| psutil>=5.6.1 | psutil>=5.6.1 | ||||
| pyyaml>=5.3 | |||||
| pyyaml>=5.3.1 | |||||
| scipy>=1.3.3 | scipy>=1.3.3 | ||||
| scikit-learn>=0.23.1 | |||||
| scikit-learn>=0.21.2 | |||||
| six>=1.12.0 | six>=1.12.0 | ||||
| Werkzeug>=1.0.0 | Werkzeug>=1.0.0 | ||||
| pandas>=1.0.4 | pandas>=1.0.4 | ||||
| @@ -208,6 +208,7 @@ if __name__ == '__main__': | |||||
| 'mindinsight=mindinsight.utils.command:main', | 'mindinsight=mindinsight.utils.command:main', | ||||
| 'mindconverter=mindinsight.mindconverter.cli:cli_entry', | 'mindconverter=mindinsight.mindconverter.cli:cli_entry', | ||||
| 'mindwizard=mindinsight.wizard.cli:cli_entry', | 'mindwizard=mindinsight.wizard.cli:cli_entry', | ||||
| 'mindoptimizer=mindinsight.optimizer.cli:cli_entry', | |||||
| ], | ], | ||||
| }, | }, | ||||
| python_requires='>=3.7', | python_requires='>=3.7', | ||||
| @@ -98,14 +98,14 @@ class TestModelLineage(TestCase): | |||||
| train_callback.end(RunContext(self.run_context)) | train_callback.end(RunContext(self.run_context)) | ||||
| LINEAGE_DATA_MANAGER.start_load_data().join() | LINEAGE_DATA_MANAGER.start_load_data().join() | ||||
| res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition) | |||||
| res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition) | |||||
| assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 10 | assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 10 | ||||
| run_context = self.run_context | run_context = self.run_context | ||||
| run_context['epoch_num'] = 14 | run_context['epoch_num'] = 14 | ||||
| train_callback.end(RunContext(run_context)) | train_callback.end(RunContext(run_context)) | ||||
| LINEAGE_DATA_MANAGER.start_load_data().join() | LINEAGE_DATA_MANAGER.start_load_data().join() | ||||
| res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition) | |||||
| res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition) | |||||
| assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 14 | assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 14 | ||||
| @pytest.mark.scene_eval(3) | @pytest.mark.scene_eval(3) | ||||
| @@ -198,7 +198,7 @@ class TestModelLineage(TestCase): | |||||
| train_callback.end(RunContext(run_context_customized)) | train_callback.end(RunContext(run_context_customized)) | ||||
| LINEAGE_DATA_MANAGER.start_load_data().join() | LINEAGE_DATA_MANAGER.start_load_data().join() | ||||
| res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition) | |||||
| res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition) | |||||
| assert res.get('object')[0].get('model_lineage', {}).get('loss_function') \ | assert res.get('object')[0].get('model_lineage', {}).get('loss_function') \ | ||||
| == 'SoftmaxCrossEntropyWithLogits' | == 'SoftmaxCrossEntropyWithLogits' | ||||
| assert res.get('object')[0].get('model_lineage', {}).get('network') == 'ResNet' | assert res.get('object')[0].get('model_lineage', {}).get('network') == 'ResNet' | ||||
| @@ -190,7 +190,7 @@ class TestModelApi(TestCase): | |||||
| search_condition = { | search_condition = { | ||||
| 'sorted_name': 'summary_dir' | 'sorted_name': 'summary_dir' | ||||
| } | } | ||||
| res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) | |||||
| res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) | |||||
| expect_objects = expect_result.get('object') | expect_objects = expect_result.get('object') | ||||
| for idx, res_object in enumerate(res.get('object')): | for idx, res_object in enumerate(res.get('object')): | ||||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | ||||
| @@ -228,7 +228,7 @@ class TestModelApi(TestCase): | |||||
| ], | ], | ||||
| 'count': 2 | 'count': 2 | ||||
| } | } | ||||
| partial_res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) | |||||
| partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) | |||||
| expect_objects = expect_result.get('object') | expect_objects = expect_result.get('object') | ||||
| for idx, res_object in enumerate(partial_res.get('object')): | for idx, res_object in enumerate(partial_res.get('object')): | ||||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | ||||
| @@ -266,7 +266,7 @@ class TestModelApi(TestCase): | |||||
| ], | ], | ||||
| 'count': 2 | 'count': 2 | ||||
| } | } | ||||
| partial_res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) | |||||
| partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) | |||||
| expect_objects = expect_result.get('object') | expect_objects = expect_result.get('object') | ||||
| for idx, res_object in enumerate(partial_res.get('object')): | for idx, res_object in enumerate(partial_res.get('object')): | ||||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | ||||
| @@ -295,7 +295,7 @@ class TestModelApi(TestCase): | |||||
| ], | ], | ||||
| 'count': 3 | 'count': 3 | ||||
| } | } | ||||
| partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1) | |||||
| partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1) | |||||
| expect_objects = expect_result.get('object') | expect_objects = expect_result.get('object') | ||||
| for idx, res_object in enumerate(partial_res1.get('object')): | for idx, res_object in enumerate(partial_res1.get('object')): | ||||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | ||||
| @@ -314,7 +314,7 @@ class TestModelApi(TestCase): | |||||
| 'object': [], | 'object': [], | ||||
| 'count': 0 | 'count': 0 | ||||
| } | } | ||||
| partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2) | |||||
| partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2) | |||||
| assert expect_result == partial_res2 | assert expect_result == partial_res2 | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -335,7 +335,7 @@ class TestModelApi(TestCase): | |||||
| 'eq': self._empty_train_id | 'eq': self._empty_train_id | ||||
| } | } | ||||
| } | } | ||||
| res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) | |||||
| res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) | |||||
| assert expect_result == res | assert expect_result == res | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -366,7 +366,7 @@ class TestModelApi(TestCase): | |||||
| ], | ], | ||||
| 'count': 1 | 'count': 1 | ||||
| } | } | ||||
| res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition) | |||||
| res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition) | |||||
| assert expect_result == res | assert expect_result == res | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -386,6 +386,7 @@ class TestModelApi(TestCase): | |||||
| 'The search_condition element summary_dir should be dict.', | 'The search_condition element summary_dir should be dict.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -398,6 +399,7 @@ class TestModelApi(TestCase): | |||||
| 'The sorted_name must exist when sorted_type exists.', | 'The sorted_name must exist when sorted_type exists.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -408,6 +410,7 @@ class TestModelApi(TestCase): | |||||
| 'Invalid search_condition type, it should be dict.', | 'Invalid search_condition type, it should be dict.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -420,6 +423,7 @@ class TestModelApi(TestCase): | |||||
| 'The limit must be int.', | 'The limit must be int.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -440,6 +444,7 @@ class TestModelApi(TestCase): | |||||
| 'The offset must be int.', | 'The offset must be int.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -454,6 +459,7 @@ class TestModelApi(TestCase): | |||||
| 'The search attribute not supported.', | 'The search attribute not supported.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -475,6 +481,7 @@ class TestModelApi(TestCase): | |||||
| 'The sorted_type must be ascending or descending', | 'The sorted_type must be ascending or descending', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -489,6 +496,7 @@ class TestModelApi(TestCase): | |||||
| 'The compare condition should be in', | 'The compare condition should be in', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -503,6 +511,7 @@ class TestModelApi(TestCase): | |||||
| 'The parameter metric/accuracy is invalid.', | 'The parameter metric/accuracy is invalid.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -526,7 +535,7 @@ class TestModelApi(TestCase): | |||||
| 'object': [], | 'object': [], | ||||
| 'count': 0 | 'count': 0 | ||||
| } | } | ||||
| partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1) | |||||
| partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1) | |||||
| assert expect_result == partial_res1 | assert expect_result == partial_res1 | ||||
| # the (offset + 1) * limit > count | # the (offset + 1) * limit > count | ||||
| @@ -542,7 +551,7 @@ class TestModelApi(TestCase): | |||||
| 'object': [], | 'object': [], | ||||
| 'count': 2 | 'count': 2 | ||||
| } | } | ||||
| partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2) | |||||
| partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2) | |||||
| assert expect_result == partial_res2 | assert expect_result == partial_res2 | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -566,6 +575,7 @@ class TestModelApi(TestCase): | |||||
| f'The parameter {condition_key} is invalid. Its operation should be `eq`, `in` or `not_in`.', | f'The parameter {condition_key} is invalid. Its operation should be `eq`, `in` or `not_in`.', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -589,6 +599,7 @@ class TestModelApi(TestCase): | |||||
| "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.", | "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.", | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -610,6 +621,7 @@ class TestModelApi(TestCase): | |||||
| 'The sorted_name must be in', | 'The sorted_name must be in', | ||||
| filter_summary_lineage, | filter_summary_lineage, | ||||
| LINEAGE_DATA_MANAGER, | LINEAGE_DATA_MANAGER, | ||||
| None, | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @@ -23,7 +23,7 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError | |||||
| class TestValidateSearchModelCondition(TestCase): | class TestValidateSearchModelCondition(TestCase): | ||||
| """Test the mothod of validate_search_model_condition.""" | |||||
| """Test the method of validate_search_model_condition.""" | |||||
| def test_validate_search_model_condition_param_type_error(self): | def test_validate_search_model_condition_param_type_error(self): | ||||
| """Test the method of validate_search_model_condition with LineageParamTypeError.""" | """Test the method of validate_search_model_condition with LineageParamTypeError.""" | ||||
| condition = { | condition = { | ||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test common module.""" | |||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test validator.""" | |||||
| @@ -0,0 +1,161 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test optimizer config schema.""" | |||||
| from copy import deepcopy | |||||
| from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig | |||||
| from mindinsight.optimizer.common.enums import TargetGroup, HyperParamSource | |||||
| _BASE_DATA = { | |||||
| 'command': 'python ./test.py', | |||||
| 'summary_base_dir': './demo_lineage_r0.3', | |||||
| 'tuner': { | |||||
| 'name': 'gp', | |||||
| 'args': { | |||||
| 'method': 'ucb' | |||||
| } | |||||
| }, | |||||
| 'target': { | |||||
| 'group': 'metric', | |||||
| 'name': 'Accuracy', | |||||
| 'goal': 'maximize' | |||||
| }, | |||||
| 'parameters': { | |||||
| 'learning_rate': { | |||||
| 'bounds': [0.0001, 0.001], | |||||
| 'type': 'float' | |||||
| }, | |||||
| 'batch_size': { | |||||
| 'choice': [32, 64, 128, 256], | |||||
| 'type': 'int' | |||||
| }, | |||||
| 'decay_step': { | |||||
| 'choice': [20], | |||||
| 'type': 'int' | |||||
| } | |||||
| } | |||||
| } | |||||
| class TestOptimizerConfig: | |||||
| """Test the method of validate_search_model_condition.""" | |||||
| _config_dict = dict(_BASE_DATA) | |||||
| def test_config_dict_with_wrong_type(self): | |||||
| """Test config dict with wrong type.""" | |||||
| config_dict = deepcopy(self._config_dict) | |||||
| init_list = ['a'] | |||||
| init_str = 'a' | |||||
| config_dict['command'] = init_list | |||||
| config_dict['summary_base_dir'] = init_list | |||||
| config_dict['target']['name'] = init_list | |||||
| config_dict['target']['goal'] = init_list | |||||
| config_dict['parameters']['learning_rate']['bounds'] = init_str | |||||
| config_dict['parameters']['learning_rate']['choice'] = init_str | |||||
| expected_err = { | |||||
| 'command': ["Value type should be 'str'."], | |||||
| 'parameters': { | |||||
| 'learning_rate': { | |||||
| 'bounds': ["Value type should be 'list'."], | |||||
| 'choice': ["Value type should be 'list'."] | |||||
| } | |||||
| }, | |||||
| 'summary_base_dir': ["Value type should be 'str'."], | |||||
| 'target': { | |||||
| 'name': ["Value type should be 'str'."], | |||||
| 'goal': ["Value type should be 'str'."] | |||||
| } | |||||
| } | |||||
| err = OptimizerConfig().validate(config_dict) | |||||
| assert expected_err == err | |||||
| def test_config_dict_with_wrong_value(self): | |||||
| """Test config dict with wrong value.""" | |||||
| config_dict = deepcopy(self._config_dict) | |||||
| init_list = ['a'] | |||||
| init_str = 'a' | |||||
| config_dict['target']['group'] = init_str | |||||
| config_dict['target']['goal'] = init_str | |||||
| config_dict['tuner']['name'] = init_str | |||||
| config_dict['parameters']['learning_rate']['bounds'] = init_list | |||||
| config_dict['parameters']['learning_rate']['choice'] = init_list | |||||
| config_dict['parameters']['learning_rate']['type'] = init_str | |||||
| expected_err = { | |||||
| 'parameters': { | |||||
| 'learning_rate': { | |||||
| 'bounds': { | |||||
| 0: ['Value(s) should be integer or float.'] | |||||
| }, | |||||
| 'choice': { | |||||
| 0: ['Value(s) should be integer or float.'] | |||||
| }, | |||||
| 'type': ["The type should be in ['int', 'float']."] | |||||
| } | |||||
| }, | |||||
| 'target': { | |||||
| 'goal': ["Value should be in ['maximize', 'minimize']. Current value is a."], | |||||
| 'group': ["Value should be in ['system_defined', 'metric']. Current value is a."]}, | |||||
| 'tuner': { | |||||
| 'name': ['Must be one of: gp.'] | |||||
| } | |||||
| } | |||||
| err = OptimizerConfig().validate(config_dict) | |||||
| assert expected_err == err | |||||
| def test_target_combination(self): | |||||
| """Test target combination.""" | |||||
| config_dict = deepcopy(self._config_dict) | |||||
| config_dict['target']['group'] = TargetGroup.SYSTEM_DEFINED.value | |||||
| config_dict['target']['name'] = 'a' | |||||
| expected_err = { | |||||
| 'target': { | |||||
| 'group': 'This target is not system defined. Current group is: system_defined.' | |||||
| } | |||||
| } | |||||
| err = OptimizerConfig().validate(config_dict) | |||||
| assert expected_err == err | |||||
| def test_parameters_combination1(self): | |||||
| """Test parameters combination.""" | |||||
| config_dict = deepcopy(self._config_dict) | |||||
| config_dict['parameters']['decay_step']['source'] = HyperParamSource.SYSTEM_DEFINED.value | |||||
| expected_err = { | |||||
| 'parameters': { | |||||
| 'decay_step': { | |||||
| 'source': 'This param is not system defined. Current source is: system_defined.' | |||||
| } | |||||
| } | |||||
| } | |||||
| err = OptimizerConfig().validate(config_dict) | |||||
| assert expected_err == err | |||||
| def test_parameters_combination2(self): | |||||
| """Test parameters combination.""" | |||||
| config_dict = deepcopy(self._config_dict) | |||||
| config_dict['parameters']['decay_step']['bounds'] = [1, 40] | |||||
| expected_err = { | |||||
| 'parameters': { | |||||
| 'decay_step': { | |||||
| '_schema': ["Only one of ['bounds', 'choice'] should be specified."] | |||||
| } | |||||
| } | |||||
| } | |||||
| err = OptimizerConfig().validate(config_dict) | |||||
| assert expected_err == err | |||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test utils.""" | |||||