1. add cli entry
2. add more validation for hyper config
3. add validator for config dict
4. add st for optimizer config validator
5. add custom lineage for hyper config
@@ -43,7 +43,7 @@ def get_optimize_targets():

def _get_optimize_targets(data_manager, search_condition=None):
    """Get optimize targets."""
    table = get_lineage_table(data_manager, search_condition)
    table = get_lineage_table(data_manager=data_manager, search_condition=search_condition)
    target_summaries = []
    for target in table.target_names:
@@ -15,6 +15,7 @@

"""This file is used to parse lineage info."""
import os

from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageSummaryAnalyzeException, \
    LineageEventNotExistException, LineageEventFieldNotExistException, LineageFileNotFoundError, \
    MindInsightException

@@ -177,13 +178,49 @@ class LineageParser:

class LineageOrganizer:
    """Lineage organizer."""

    def __init__(self, data_manager):
    def __init__(self, data_manager=None, summary_base_dir=None):
        self._data_manager = data_manager
        self._summary_base_dir = summary_base_dir
        self._check_params()
        self._super_lineage_objs = {}
        self._organize_from_cache()
        self._organize_from_disk()

    def _check_params(self):
        """Check params."""
        if self._data_manager is not None and self._summary_base_dir is not None:
            self._summary_base_dir = None

    def _organize_from_disk(self):
        """Organize lineage objs from disk."""
        if self._summary_base_dir is None:
            return

        summary_watcher = SummaryWatcher()
        relative_dirs = summary_watcher.list_summary_directories(
            summary_base_dir=self._summary_base_dir
        )

        no_lineage_count = 0
        for item in relative_dirs:
            relative_dir = item.get('relative_path')
            update_time = item.get('update_time')
            abs_summary_dir = os.path.realpath(os.path.join(self._summary_base_dir, relative_dir))

            try:
                lineage_parser = LineageParser(relative_dir, abs_summary_dir, update_time)
                super_lineage_obj = lineage_parser.super_lineage_obj
                if super_lineage_obj is not None:
                    self._super_lineage_objs.update({abs_summary_dir: super_lineage_obj})
            except LineageFileNotFoundError:
                no_lineage_count += 1

        if no_lineage_count == len(relative_dirs):
            logger.info('There is no summary log file under summary_base_dir.')

    def _organize_from_cache(self):
        """Organize lineage objs from cache."""
        if self._data_manager is None:
            return
        brief_cache = self._data_manager.get_brief_cache()
        cache_items = brief_cache.cache_items
        for relative_dir, cache_train_job in cache_items.items():
@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySumm
from mindinsight.lineagemgr.common.log import logger as log
from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition, validate_condition
from mindinsight.lineagemgr.common.validator.validate_path import validate_and_normalize_path
from mindinsight.lineagemgr.lineage_parser import LineageOrganizer
from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.optimizer.common.enums import ReasonCode

@@ -33,7 +34,7 @@ USER_DEFINED_PREFIX = "[U]"
USER_DEFINED_INFO_LIMIT = 100

def filter_summary_lineage(data_manager, search_condition=None):
def filter_summary_lineage(data_manager=None, summary_base_dir=None, search_condition=None):
    """
    Filter summary lineage from data_manager or parsing from summaries.

@@ -43,8 +44,18 @@ def filter_summary_lineage(data_manager, search_condition=None):
    Args:
        data_manager (DataManager): Data manager defined as
            mindinsight.datavisual.data_transform.data_manager.DataManager
        summary_base_dir (str): The summary base directory. It contains summary
            directories generated by training.
        search_condition (dict): The search condition.
    """
    if data_manager is None and summary_base_dir is None:
        raise LineageParamTypeError("One of data_manager or summary_base_dir needs to be specified.")
    if data_manager is None:
        summary_base_dir = validate_and_normalize_path(summary_base_dir, 'summary_base_dir')
    else:
        summary_base_dir = data_manager.summary_base_dir

    search_condition = {} if search_condition is None else search_condition

    try:
@@ -56,7 +67,7 @@ def filter_summary_lineage(data_manager, search_condition=None):
        raise LineageSearchConditionParamError(str(error.message))

    try:
        lineage_objects = LineageOrganizer(data_manager).super_lineage_objs
        lineage_objects = LineageOrganizer(data_manager, summary_base_dir).super_lineage_objs
        result = Querier(lineage_objects).filter_summary_lineage(condition=search_condition)
    except LineageSummaryParseException:
        result = {'object': [], 'count': 0}
@@ -68,12 +79,13 @@ def filter_summary_lineage(data_manager, search_condition=None):
    return result

def get_flattened_lineage(data_manager, search_condition=None):
def get_flattened_lineage(data_manager=None, summary_base_dir=None, search_condition=None):
    """
    Get lineage data in a table from data manager or summaries.

    Args:
        data_manager (mindinsight.datavisual.data_manager.DataManager): An object to manage loading.
        summary_base_dir (str): The base directory for train jobs.
        search_condition (dict): The search condition.

    Returns:
@@ -81,7 +93,7 @@ def get_flattened_lineage(data_manager, search_condition=None):
    """
    flatten_dict, user_count = {'train_id': []}, 0
    lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition).get("object", [])
    lineages = filter_summary_lineage(data_manager, summary_base_dir, search_condition).get("object", [])
    for index, lineage in enumerate(lineages):
        flatten_dict['train_id'].append(lineage.get("summary_dir"))
        for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
@@ -222,9 +234,9 @@ class LineageTable:
        return self._drop_columns_info

def get_lineage_table(data_manager, search_condition=None):
def get_lineage_table(data_manager=None, summary_base_dir=None, search_condition=None):
    """Get lineage table from data_manager or summary_base_dir."""
    lineage_table = get_flattened_lineage(data_manager, search_condition)
    lineage_table = get_flattened_lineage(data_manager, summary_base_dir, search_condition)
    lineage_table = LineageTable(pd.DataFrame(lineage_table))
    return lineage_table
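
With this change, lineage can be queried without a running DataManager. A minimal sketch (the directory path is illustrative):

    from mindinsight.lineagemgr.model import get_lineage_table

    # Parse lineage directly from summary files on disk; no DataManager needed.
    table = get_lineage_table(summary_base_dir='./summaries')
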
@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Command module."""
import argparse
import os
import sys

import mindinsight
from mindinsight.optimizer.tuner import Tuner

class ConfigAction(argparse.Action):
    """Config file path action class definition."""

    def __call__(self, parser_in, namespace, values, option_string=None):
        """
        Inherited __call__ method from argparse.Action.

        Args:
            parser_in (ArgumentParser): Passed-in argument parser.
            namespace (Namespace): Namespace object to hold arguments.
            values (object): Argument values with type depending on argument definition.
            option_string (str): Option string for specific argument name.
        """
        config_path = os.path.realpath(values)
        if not os.path.exists(config_path):
            parser_in.error(f'{option_string} {config_path} does not exist.')
        setattr(namespace, self.dest, config_path)

class IterAction(argparse.Action):
    """Iteration times action class definition."""

    def __call__(self, parser_in, namespace, values, option_string=None):
        """
        Inherited __call__ method from argparse.Action.

        Args:
            parser_in (ArgumentParser): Passed-in argument parser.
            namespace (Namespace): Namespace object to hold arguments.
            values (object): Argument values with type depending on argument definition.
            option_string (str): Option string for specific argument name.
        """
        iter_times = values
        if iter_times <= 0:
            parser_in.error(f'{option_string} {iter_times} should be a positive integer.')
        setattr(namespace, self.dest, iter_times)

parser = argparse.ArgumentParser(
    prog='mindoptimizer',
    description='MindOptimizer CLI entry point (version: {})'.format(mindinsight.__version__))

parser.add_argument(
    '--version',
    action='version',
    version='%(prog)s ({})'.format(mindinsight.__version__))

parser.add_argument(
    '--config',
    type=str,
    action=ConfigAction,
    required=True,
    default=os.path.join(os.getcwd(), 'output'),
    help="Specify path for config file."
)

parser.add_argument(
    '--iter',
    type=int,
    action=IterAction,
    default=1,
    help="Optional, specify the number of times to run the command in the config file."
)

def cli_entry():
    """Cli entry."""
    argv = sys.argv[1:]
    if not argv:
        argv = ['-h']
        args = parser.parse_args(argv)
    else:
        args = parser.parse_args()
    tuner = Tuner(args.config)
    tuner.optimize(max_expr_times=args.iter)
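
A quick sketch of how the new parser behaves in isolation (the paths here are illustrative, not part of this change):

    import tempfile
    from mindinsight.optimizer.cli import parser

    # ConfigAction rejects missing paths, so point --config at a real file.
    with tempfile.NamedTemporaryFile(suffix='.yaml') as tmp:
        args = parser.parse_args(['--config', tmp.name, '--iter', '3'])
        print(args.config, args.iter)  # <realpath of tmp.name> 3
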
@@ -14,3 +14,4 @@
# ============================================================================
"""Common constants for optimizer."""

HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
HYPER_CONFIG_LEN_LIMIT = 100000
@@ -48,12 +48,17 @@ class TuneMethod(BaseEnum):
    GP = 'gp'

class GPSupportArgs(BaseEnum):
    METHOD = 'method'

class HyperParamKey(BaseEnum):
    """Config keys for hyper parameters."""
    BOUND = 'bounds'
    CHOICE = 'choice'
    DECIMAL = 'decimal'
    TYPE = 'type'
    SOURCE = 'source'

class HyperParamType(BaseEnum):
@@ -73,3 +78,25 @@ class TargetGoal(BaseEnum):
    """Goal for target."""
    MAXIMUM = 'maximize'
    MINIMUM = 'minimize'

class HyperParamSource(BaseEnum):
    SYSTEM_DEFINED = 'system_defined'
    USER_DEFINED = 'user_defined'

class TargetGroup(BaseEnum):
    SYSTEM_DEFINED = 'system_defined'
    METRIC = 'metric'

class TunableSystemDefinedParams(BaseEnum):
    """Tunable metadata keys of lineage collection."""
    BATCH_SIZE = 'batch_size'
    EPOCH = 'epoch'
    LEARNING_RATE = 'learning_rate'

class SystemDefinedTargets(BaseEnum):
    """System defined targets."""
    LOSS = 'loss'
@@ -48,3 +48,19 @@ class OptimizerTerminateError(MindInsightException):
        super(OptimizerTerminateError, self).__init__(OptimizerErrors.OPTIMIZER_TERMINATE,
                                                      error_msg,
                                                      http_code=400)

class ConfigParamError(MindInsightException):
    """Config parameter error."""
    def __init__(self, error_msg="Invalid parameter."):
        super(ConfigParamError, self).__init__(OptimizerErrors.CONFIG_PARAM_ERROR,
                                               error_msg,
                                               http_code=400)

class HyperConfigEnvError(MindInsightException):
    """Hyper config environment error."""
    def __init__(self, error_msg="Hyper config is not correct."):
        super(HyperConfigEnvError, self).__init__(OptimizerErrors.HYPER_CONFIG_ENV_ERROR,
                                                  error_msg,
                                                  http_code=400)
@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Common validator modules for optimizer."""
@@ -0,0 +1,202 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Validator for optimizer config."""
import math

from marshmallow import Schema, fields, ValidationError, validates, validate, validates_schema

from mindinsight.optimizer.common.enums import TuneMethod, AcquisitionFunctionEnum, GPSupportArgs, \
    HyperParamSource, HyperParamType, TargetGoal, TargetKey, TunableSystemDefinedParams, TargetGroup, \
    HyperParamKey, SystemDefinedTargets

_BOUND_LEN = 2
_NUMBER_ERR_MSG = "Value(s) should be integer or float."
_TYPE_ERR_MSG = "Value type should be %r."
_VALUE_ERR_MSG = "Value should be in %s. Current value is %s."

def _generate_schema_err_msg(err_msg, *args):
    """Organize error messages."""
    if args:
        err_msg = err_msg % args
    return {"invalid": err_msg}

def include_integer(low, high):
    """Check if the range [low, high) includes an integer."""
    def _in_range(num, low, high):
        """Check if num is in [low, high)."""
        return low <= num < high

    if _in_range(math.ceil(low), low, high) or _in_range(math.floor(high), low, high):
        return True
    return False
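
# For example (illustrative values): [0.2, 0.9) contains no integer, while
# [0.2, 1.1) contains 1, so an 'int' param may legally be bounded by it:
#     include_integer(0.2, 0.9)  # False
#     include_integer(0.2, 1.1)  # True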
class TunerSchema(Schema):
    """Schema for tuner."""
    dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")

    name = fields.Str(required=True,
                      validate=validate.OneOf(TuneMethod.list_members()),
                      error_messages=_generate_schema_err_msg("Value should be in %s." % TuneMethod.list_members()))
    args = fields.Dict(error_messages=dict_err_msg)

    @validates("args")
    def check_args(self, data):
        """Check args for tuner."""
        data_keys = list(data.keys())
        support_args = GPSupportArgs.list_members()
        if not set(data_keys).issubset(set(support_args)):
            raise ValidationError("Only support setting %s for tuner. "
                                  "Current key(s): %s." % (support_args, data_keys))

        method = data.get(GPSupportArgs.METHOD.value)
        if not isinstance(method, str):
            raise ValidationError("The 'method' type should be str.")
        if method not in AcquisitionFunctionEnum.list_members():
            raise ValidationError("Supported acquisition function must be one of %s. Current value is %r." %
                                  (AcquisitionFunctionEnum.list_members(), method))

class ParameterSchema(Schema):
    """Schema for parameter."""
    number_err_msg = _generate_schema_err_msg(_NUMBER_ERR_MSG)
    list_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "list")
    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")

    bounds = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
    choice = fields.List(fields.Number(error_messages=number_err_msg), error_messages=list_err_msg)
    type = fields.Str(error_messages=str_err_msg)
    source = fields.Str(error_messages=str_err_msg)

    @validates("bounds")
    def check_bounds(self, bounds):
        """Check if bounds are valid."""
        if len(bounds) != _BOUND_LEN:
            raise ValidationError("Length of bounds should be %s." % _BOUND_LEN)
        if bounds[1] <= bounds[0]:
            raise ValidationError("The upper bound must be greater than the lower bound. "
                                  "The range is [lower_bound, upper_bound).")

    @validates("type")
    def check_type(self, type_in):
        """Check if type is valid."""
        if type_in not in HyperParamType.list_members():
            raise ValidationError("The type should be in %s." % HyperParamType.list_members())

    @validates("source")
    def check_source(self, source):
        """Check if source is valid."""
        if source not in HyperParamSource.list_members():
            raise ValidationError(_VALUE_ERR_MSG % (HyperParamSource.list_members(), source))

    @validates_schema
    def check_combination(self, data, **kwargs):
        """Check the combination of parameters."""
        bound_key = HyperParamKey.BOUND.value
        choice_key = HyperParamKey.CHOICE.value
        type_key = HyperParamKey.TYPE.value

        # check bounds and type
        bounds = data.get(bound_key)
        param_type = data.get(type_key)
        if bounds is not None:
            if param_type is None:
                raise ValidationError("If %r is specified, %r should also be specified." %
                                      (HyperParamKey.BOUND.value, HyperParamKey.TYPE.value))
            if param_type == HyperParamType.INT.value and not include_integer(bounds[0], bounds[1]):
                raise ValidationError("No integer in 'bounds', please modify it.")

        # check bounds and choice
        if (bound_key in data and choice_key in data) or (bound_key not in data and choice_key not in data):
            raise ValidationError("Only one of [%r, %r] should be specified." %
                                  (bound_key, choice_key))

class TargetSchema(Schema):
    """Schema for target."""
    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")

    group = fields.Str(error_messages=str_err_msg)
    name = fields.Str(required=True, error_messages=str_err_msg)
    goal = fields.Str(error_messages=str_err_msg)

    @validates("group")
    def check_group(self, group):
        """Check if group is valid."""
        if group not in TargetGroup.list_members():
            raise ValidationError(_VALUE_ERR_MSG % (TargetGroup.list_members(), group))

    @validates("goal")
    def check_goal(self, goal):
        """Check if goal is valid."""
        if goal not in TargetGoal.list_members():
            raise ValidationError(_VALUE_ERR_MSG % (TargetGoal.list_members(), goal))

    @validates_schema
    def check_combination(self, data, **kwargs):
        """Check the combination of parameters."""
        if TargetKey.GROUP.value not in data:
            # If name is in system_defined keys, group will be 'system_defined', else 'user_defined'.
            return
        name = data.get(TargetKey.NAME.value)
        group = data.get(TargetKey.GROUP.value)

        if group == TargetGroup.SYSTEM_DEFINED.value and name not in SystemDefinedTargets.list_members():
            raise ValidationError({
                TargetKey.GROUP.value: "This target is not system defined. Current group is: %s." % group})

class OptimizerConfig(Schema):
    """Schema for optimizer config."""
    dict_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "dict")
    str_err_msg = _generate_schema_err_msg(_TYPE_ERR_MSG, "str")

    summary_base_dir = fields.Str(required=True, error_messages=str_err_msg)
    command = fields.Str(required=True, error_messages=str_err_msg)
    tuner = fields.Dict(required=True, error_messages=dict_err_msg)
    target = fields.Dict(required=True, error_messages=dict_err_msg)
    parameters = fields.Dict(required=True, error_messages=dict_err_msg)

    @validates("tuner")
    def check_tuner(self, data):
        """Check tuner."""
        err = TunerSchema().validate(data)
        if err:
            raise ValidationError(err)

    @validates("parameters")
    def check_parameters(self, parameters):
        """Check parameters."""
        for name, value in parameters.items():
            err = ParameterSchema().validate(value)
            if err:
                raise ValidationError({name: err})

            if HyperParamKey.SOURCE.value not in value:
                # If the param is in system_defined keys, source will be 'system_defined', else 'user_defined'.
                continue
            source = value.get(HyperParamKey.SOURCE.value)
            if source == HyperParamSource.SYSTEM_DEFINED.value and \
                    name not in TunableSystemDefinedParams.list_members():
                raise ValidationError({
                    name: {"source": "This param is not system defined. Current source is: %s." % source}})

    @validates("target")
    def check_target(self, target):
        """Check target."""
        err = TargetSchema().validate(target)
        if err:
            raise ValidationError(err)
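
For orientation, a sketch of how this schema is consumed (the config values below are illustrative; validate() returns an empty dict when everything passes):

    sample = {
        'command': 'python train.py',
        'summary_base_dir': './summaries',
        'tuner': {'name': 'gp'},
        'target': {'name': 'loss', 'goal': 'minimize'},
        'parameters': {'learning_rate': {'bounds': [0.0001, 0.001], 'type': 'float'}},
    }
    assert OptimizerConfig().validate(sample) == {}
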
@@ -15,30 +15,73 @@

"""Hyper config."""
import json
import os

from attrdict import AttrDict

from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
from mindinsight.optimizer.common.exceptions import HyperConfigError
from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME, HYPER_CONFIG_LEN_LIMIT
from mindinsight.optimizer.common.exceptions import HyperConfigEnvError, HyperConfigError

_HYPER_CONFIG_LEN_LIMIT = 100000

class AttributeDict(dict):
    """A dict that can be accessed by attribute."""
    def __init__(self, d=None):
        super().__init__()
        if d is not None:
            for k, v in d.items():
                self[k] = v

    def __key(self, key):
        """Get key."""
        return "" if key is None else key

    def __setattr__(self, key, value):
        """Set attribute for object."""
        self[self.__key(key)] = value

    def __getattr__(self, key):
        """
        Get attribute value by attribute name.

        Args:
            key (str): Attribute name.

        Returns:
            Any, attribute value.

        Raises:
            AttributeError: If the key does not exist.
        """
        value = self.get(self.__key(key))
        if value is None:
            raise AttributeError("The attribute %r does not exist." % key)
        return value

    def __getitem__(self, key):
        """Get attribute value by attribute name."""
        value = super().get(self.__key(key))
        if value is None:
            raise AttributeError("The attribute %r does not exist." % key)
        return value

    def __setitem__(self, key, value):
        """Set attribute for object."""
        return super().__setitem__(self.__key(key), value)

class HyperConfig:
    """
    Hyper config.

    Init hyper config:
    >>> hyper_config = HyperConfig()

    Get suggest params:
    >>> param_obj = hyper_config.params
    >>> learning_rate = params.learning_rate

    Get summary dir:
    >>> summary_dir = hyper_config.summary_dir

    Record by SummaryCollector:
    >>> summary_cb = SummaryCollector(summary_dir)
    """
class HyperConfig:
    """
    Hyper config.

    1. Init HyperConfig.
    2. Get suggested params and summary_dir.
    3. Record by SummaryCollector with summary_dir.

    Examples:
        >>> hyper_config = HyperConfig()
        >>> params = hyper_config.params
        >>> learning_rate = params.learning_rate
        >>> batch_size = params.batch_size
        >>> summary_dir = hyper_config.summary_dir
        >>> summary_cb = SummaryCollector(summary_dir)
    """
    def __init__(self):
        self._init_validate_hyper_config()

@@ -47,10 +90,10 @@ class HyperConfig:
        """Init and validate hyper config."""
        hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
        if hyper_config is None:
            raise HyperConfigError("Hyper config is not in system environment.")
        if len(hyper_config) > _HYPER_CONFIG_LEN_LIMIT:
            raise HyperConfigError("Hyper config is too long. The length limit is %s, the length of "
                                   "hyper_config is %s." % (_HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
            raise HyperConfigEnvError("Hyper config is not in system environment.")
        if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT:
            raise HyperConfigEnvError("Hyper config is too long. The length limit is %s, the length of "
                                      "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))

        try:
            hyper_config = json.loads(hyper_config)
@@ -60,8 +103,7 @@ class HyperConfig:
            raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc))

        self._validate_hyper_config(hyper_config)
        self._summary_dir = hyper_config.get('summary_dir')
        self._param_obj = AttrDict(hyper_config.get('params'))
        self._hyper_config = hyper_config

    def _validate_hyper_config(self, hyper_config):
        """Validate hyper config."""
@@ -86,9 +128,13 @@ class HyperConfig:
    @property
    def params(self):
        """Get params."""
        return self._param_obj
        return AttributeDict(self._hyper_config.get('params'))

    @property
    def summary_dir(self):
        """Get train summary dir path."""
        return self._summary_dir
        return self._hyper_config.get('summary_dir')

    @property
    def custom_lineage_data(self):
        """Get custom lineage data."""
        return self._hyper_config.get('custom_lineage_data')
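
To see the round trip end to end, a sketch of the payload the tuner exports and the training script then reads back (values are illustrative, and this assumes the payload passes _validate_hyper_config, whose body is not shown in this change):

    import json, os
    os.environ['MINDINSIGHT_HYPER_CONFIG'] = json.dumps({
        'params': {'learning_rate': 0.0005, 'batch_size': 64},
        'summary_dir': './summaries/train_xxx',
        'custom_lineage_data': {'decay_step': 20},
    })
    hyper_config = HyperConfig()
    lr = hyper_config.params.learning_rate   # 0.0005, via AttributeDict
    summary_dir = hyper_config.summary_dir   # './summaries/train_xxx'
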
@@ -22,16 +22,16 @@ import yaml
from marshmallow import ValidationError

from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path
from mindinsight.lineagemgr.model import get_lineage_table, LineageTable, METRIC_PREFIX
from mindinsight.lineagemgr.model import get_lineage_table, LineageTable
from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME
from mindinsight.optimizer.common.enums import TuneMethod, TargetKey, TargetGoal
from mindinsight.optimizer.common.exceptions import OptimizerTerminateError
from mindinsight.optimizer.common.enums import TuneMethod
from mindinsight.optimizer.common.exceptions import OptimizerTerminateError, ConfigParamError
from mindinsight.optimizer.common.log import logger
from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig
from mindinsight.optimizer.tuners.gp_tuner import GPBaseTuner
from mindinsight.optimizer.utils.param_handler import organize_params_target
from mindinsight.optimizer.utils.utils import get_nested_message
from mindinsight.utils.exceptions import MindInsightException, ParamValueError, FileSystemPermissionError, UnknownError

_OK = 0

@@ -51,7 +51,6 @@ class Tuner:
    def __init__(self, config_path: str):
        self._config_info = self._validate_config(config_path)
        self._summary_base_dir = self._config_info.get('summary_base_dir')
        self._data_manager = self._init_data_manager()
        self._dir_prefix = 'train'

    def _validate_config(self, config_path):
@@ -65,11 +64,17 @@ class Tuner:
        except Exception as exc:
            raise UnknownError("Detail: %s." % str(exc))

        # need to add validation for config_info: command, summary_base_dir, target and params.
        self._validate_config_schema(config_info)
        config_info['summary_base_dir'] = self._normalize_path("summary_base_dir", config_info.get('summary_base_dir'))
        self._make_summary_base_dir(config_info['summary_base_dir'])
        return config_info

    def _validate_config_schema(self, config_info):
        """Validate config schema."""
        error = OptimizerConfig().validate(config_info)
        if error:
            err_msg = get_nested_message(error)
            raise ConfigParamError(err_msg)

    def _make_summary_base_dir(self, summary_base_dir):
        """Check and make summary_base_dir."""
        if not os.path.exists(summary_base_dir):
@@ -82,13 +87,6 @@ class Tuner:
            except OSError as exc:
                raise UnknownError("Can not make the summary base directory. Detail: %s." % str(exc))

    def _init_data_manager(self):
        """Initialize data_manager."""
        data_manager = DataManager(summary_base_dir=self._summary_base_dir)
        data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater())
        return data_manager

    def _normalize_path(self, param_name, path):
        """Normalize config path."""
        path = os.path.realpath(path)
@@ -104,10 +102,8 @@ class Tuner:
    def _update_from_lineage(self):
        """Update lineage from lineagemgr."""
        self._data_manager.start_load_data(reload_interval=0).join()

        try:
            lineage_table = get_lineage_table(self._data_manager)
            lineage_table = get_lineage_table(summary_base_dir=self._summary_base_dir)
        except MindInsightException as err:
            logger.info("Can not query lineage. Detail: %s", str(err))
            lineage_table = None
@@ -122,12 +118,14 @@ class Tuner:
        tuner = self._config_info.get('tuner')
        for _ in range(max_expr_times):
            self._update_from_lineage()
            suggestion = self._suggest(self._lineage_table, params_info, target_info, method=tuner.get("name"))
            suggestion, user_defined_info = self._suggest(self._lineage_table, params_info, target_info, tuner)

            hyper_config = {
                'params': suggestion,
                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}')
                'summary_dir': os.path.join(self._summary_base_dir, f'{self._dir_prefix}_{str(uuid.uuid1())}'),
                'custom_lineage_data': user_defined_info
            }
            logger.info("Suggest values are: %s.", suggestion)
            os.environ[HYPER_CONFIG_ENV_NAME] = json.dumps(hyper_config)

            s = subprocess.Popen(shlex.split(command))
            s.wait()
@@ -135,29 +133,26 @@ class Tuner:
                logger.error("An error occurred during execution, the auto tuning will be terminated.")
                raise OptimizerTerminateError("An error occurred during execution, the auto tuning was terminated.")

    def _get_tuner(self, tune_method=TuneMethod.GP.value):
    def _get_tuner(self, tuner):
        """Get tuner."""
        if tune_method.lower() not in TuneMethod.list_members():
        if tuner is None:
            return GPBaseTuner()

        tuner_name = tuner.get("name").lower()
        if tuner_name not in TuneMethod.list_members():
            raise ParamValueError("'tune_method' should be in %s." % TuneMethod.list_members())

        args = tuner.get("args")
        if args is not None and args.get("method") is not None:
            return GPBaseTuner(args.get("method"))

        # Only support gaussian process regressor currently.
        return GPBaseTuner()

    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, method):
    def _suggest(self, lineage_table: LineageTable, params_info: dict, target_info: dict, tuner):
        """Get suggestions for targets."""
        tuner = self._get_tuner(method)
        target_name = target_info[TargetKey.NAME.value]
        if TargetKey.GROUP.value in target_info and target_info[TargetKey.GROUP.value] == 'metric':
            target_name = METRIC_PREFIX + target_name
        param_matrix, target_matrix = organize_params_target(lineage_table, params_info, target_name)

        if not param_matrix.empty:
            suggestion = tuner.suggest([], [], params_info)
        else:
            target_column = target_matrix[target_name].reshape((-1, 1))
            if target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value:
                target_column = -target_column
            suggestion = tuner.suggest(param_matrix, target_column, params_info)
        return suggestion
        tuner = self._get_tuner(tuner)
        param_matrix, target_column = organize_params_target(lineage_table, params_info, target_info)
        suggestion, user_defined_info = tuner.suggest(param_matrix, target_column, params_info)
        return suggestion, user_defined_info
@@ -22,10 +22,11 @@ from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern

from mindinsight.optimizer.common.enums import AcquisitionFunctionEnum, HyperParamKey
from mindinsight.optimizer.common.log import logger
from mindinsight.optimizer.tuners.base_tuner import BaseTuner
from mindinsight.optimizer.utils.param_handler import generate_arrays, match_value_type
from mindinsight.optimizer.utils.transformer import Transformer
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.optimizer.tuners.base_tuner import BaseTuner

class AcquisitionFunction:
@@ -141,8 +142,10 @@ class GPBaseTuner(BaseTuner):
        x_seeds = generate_arrays(params_info, n_iter)
        for x_try in x_seeds:
            res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max),
                           x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B")
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                res = minimize(lambda x: -self._utility_function.ac(x.reshape(1, -1), gp=gp, y_max=y_max),
                               x_try.reshape(1, -1), bounds=bounds, method="L-BFGS-B")

            if not res.success:
                continue
@@ -164,6 +167,8 @@ class GPBaseTuner(BaseTuner):
        min_lineage_rows = 2
        if not np.array(params).any() or params.shape[0] < min_lineage_rows:
            logger.info("Without valid histories or the rows of lineages < %s, "
                        "parameters will be recommended randomly.", min_lineage_rows)
            suggestion = generate_arrays(params_info)
        else:
            self._gp.fit(params, target)
@@ -174,5 +179,5 @@ class GPBaseTuner(BaseTuner):
                params_info=params_info
            )
        suggestion = Transformer.transform_list_to_dict(params_info, suggestion)
        return suggestion
        suggestion, user_defined_info = Transformer.transform_list_to_dict(params_info, suggestion)
        return suggestion, user_defined_info
@@ -14,8 +14,9 @@
# ============================================================================
"""Utils for params."""
import numpy as np

from mindinsight.lineagemgr.model import LineageTable
from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType
from mindinsight.lineagemgr.model import LineageTable, USER_DEFINED_PREFIX, METRIC_PREFIX
from mindinsight.optimizer.common.enums import HyperParamKey, HyperParamType, HyperParamSource, TargetKey, \
    TargetGoal, TunableSystemDefinedParams, TargetGroup, SystemDefinedTargets
from mindinsight.optimizer.common.log import logger

@@ -54,7 +55,6 @@ def match_value_type(array, params_info: dict):
    array_new = []
    index = 0
    for _, param_info in params_info.items():
        param_type = param_info[HyperParamKey.TYPE.value]
        value = array[index]
        if HyperParamKey.BOUND.value in param_info:
            bound = param_info[HyperParamKey.BOUND.value]
@@ -64,7 +64,7 @@ def match_value_type(array, params_info: dict):
            choices = param_info[HyperParamKey.CHOICE.value]
            nearest_index = int(np.argmin(np.fabs(np.array(choices) - value)))
            value = choices[nearest_index]
        if param_type == HyperParamType.INT.value:
        if param_info.get(HyperParamKey.TYPE.value) == HyperParamType.INT.value:
            value = int(value)
        if HyperParamKey.DECIMAL.value in param_info:
            value = np.around(value, decimals=param_info[HyperParamKey.DECIMAL.value])
@@ -73,20 +73,51 @@ def match_value_type(array, params_info: dict):
    return array_new

def organize_params_target(lineage_table: LineageTable, params_info: dict, target_name):
def organize_params_target(lineage_table: LineageTable, params_info: dict, target_info):
    """Organize params and target."""
    empty_result = np.array([])
    if lineage_table is None:
        return empty_result, empty_result

    param_keys = list(params_info.keys())
    param_keys = []
    for param_key, param_info in params_info.items():
        # It will be a user_defined param:
        # 1. if 'source' is specified as 'user_defined'
        # 2. if 'source' is not specified and the param is not a system_defined key
        source = param_info.get(HyperParamKey.SOURCE.value)
        prefix = _get_prefix(param_key, source, HyperParamSource.USER_DEFINED.value,
                             USER_DEFINED_PREFIX, TunableSystemDefinedParams.list_members())
        param_key = f'{prefix}{param_key}'
        if prefix == USER_DEFINED_PREFIX:
            param_info[HyperParamKey.SOURCE.value] = HyperParamSource.USER_DEFINED.value
        else:
            param_info[HyperParamKey.SOURCE.value] = HyperParamSource.SYSTEM_DEFINED.value
        param_keys.append(param_key)

    target_name = target_info[TargetKey.NAME.value]
    group = target_info.get(TargetKey.GROUP.value)
    prefix = _get_prefix(target_name, group, TargetGroup.METRIC.value,
                         METRIC_PREFIX, SystemDefinedTargets.list_members())
    target_name = prefix + target_name

    lineage_df = lineage_table.dataframe_data
    try:
        lineage_df = lineage_df[param_keys + [target_name]]
        lineage_df = lineage_df.dropna(axis=0, how='any')

        return lineage_df[param_keys], lineage_df[target_name]
        target_column = np.array(lineage_df[target_name])
        if TargetKey.GOAL.value in target_info and \
                target_info.get(TargetKey.GOAL.value) == TargetGoal.MAXIMUM.value:
            target_column = -target_column

        return np.array(lineage_df[param_keys]), target_column
    except KeyError as exc:
        logger.warning("Some keys do not exist in specified params or target. It will suggest params randomly. "
                       "Detail: %s.", str(exc))
        return empty_result, empty_result

def _get_prefix(name, field, other_defined_field, other_defined_prefix, system_defined_fields):
    """Get the prefix for a param or target name."""
    if (field == other_defined_field) or (field is None and name not in system_defined_fields):
        return other_defined_prefix
    return ''
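
How the prefix resolution above plays out (names taken from the enums in this change):

    # 'momentum' is not in the tunable system-defined params, so it gets the
    # user-defined prefix; 'learning_rate' is system defined, so it gets none.
    members = ['batch_size', 'epoch', 'learning_rate']
    _get_prefix('momentum', None, 'user_defined', '[U]', members)       # '[U]'
    _get_prefix('learning_rate', None, 'user_defined', '[U]', members)  # ''
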
@@ -14,6 +14,7 @@
# ============================================================================
"""Transformer."""
from mindinsight.optimizer.utils.param_handler import match_value_type
from mindinsight.optimizer.common.enums import HyperParamSource, HyperParamKey

class Transformer:
@@ -23,7 +24,12 @@ class Transformer:
        """Transform from tuner."""
        suggest_list = match_value_type(suggest_list, params_info)
        param_dict = {}
        user_defined_info = {}
        for index, param_name in enumerate(params_info):
            param_dict.update({param_name: suggest_list[index]})
            param_item = {param_name: suggest_list[index]}
            param_dict.update(param_item)
            source = params_info.get(param_name).get(HyperParamKey.SOURCE.value)
            if source is not None and source == HyperParamSource.USER_DEFINED.value:
                user_defined_info.update(param_item)
        return param_dict
        return param_dict, user_defined_info
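
A sketch of the new return shape (the params_info below is illustrative):

    params_info = {
        'learning_rate': {'bounds': [0.0001, 0.001], 'type': 'float', 'source': 'system_defined'},
        'momentum': {'bounds': [0.1, 0.9], 'type': 'float', 'source': 'user_defined'},
    }
    param_dict, user_defined_info = Transformer.transform_list_to_dict(params_info, [0.0005, 0.5])
    # param_dict         -> {'learning_rate': 0.0005, 'momentum': 0.5}
    # user_defined_info  -> {'momentum': 0.5}; only user-defined params are echoed back
    #                       so they can be recorded as custom lineage data.
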
@@ -68,3 +68,18 @@ def is_simple_numpy_number(dtype):
        return True

    return False

def get_nested_message(info: dict, out_err_msg=""):
    """Get error message from the error dict generated by schema validation."""
    if not isinstance(info, dict):
        if isinstance(info, list):
            info = info[0]
        return f'Error in {out_err_msg}: {info}'
    for key in info:
        if isinstance(key, str) and key != '_schema':
            if out_err_msg:
                out_err_msg = f'{out_err_msg}.{key}'
            else:
                out_err_msg = key
        return get_nested_message(info[key], out_err_msg)
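
For example, a nested marshmallow error dict flattens to a single dotted path (error text illustrative):

    err = {'parameters': {'learning_rate': {'bounds': ['Length of bounds should be 2.']}}}
    get_nested_message(err)
    # -> "Error in parameters.learning_rate.bounds: Length of bounds should be 2."
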
@@ -98,3 +98,5 @@ class OptimizerErrors(Enum):
    CORRELATION_NAN = 2
    HYPER_CONFIG_ERROR = 3
    OPTIMIZER_TERMINATE = 4
    CONFIG_PARAM_ERROR = 5
    HYPER_CONFIG_ENV_ERROR = 6
@@ -10,9 +10,9 @@ marshmallow>=2.19.2
numpy>=1.17.0
protobuf>=3.8.0
psutil>=5.6.1
pyyaml>=5.3
pyyaml>=5.3.1
scipy>=1.3.3
scikit-learn>=0.23.1
scikit-learn>=0.21.2
six>=1.12.0
Werkzeug>=1.0.0
pandas>=1.0.4
@@ -208,6 +208,7 @@ if __name__ == '__main__':
            'mindinsight=mindinsight.utils.command:main',
            'mindconverter=mindinsight.mindconverter.cli:cli_entry',
            'mindwizard=mindinsight.wizard.cli:cli_entry',
            'mindoptimizer=mindinsight.optimizer.cli:cli_entry',
        ],
    },
    python_requires='>=3.7',
@@ -98,14 +98,14 @@ class TestModelLineage(TestCase):
        train_callback.end(RunContext(self.run_context))
        LINEAGE_DATA_MANAGER.start_load_data().join()
        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition)
        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition)
        assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 10

        run_context = self.run_context
        run_context['epoch_num'] = 14
        train_callback.end(RunContext(run_context))
        LINEAGE_DATA_MANAGER.start_load_data().join()
        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition)
        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition)
        assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 14

    @pytest.mark.scene_eval(3)

@@ -198,7 +198,7 @@ class TestModelLineage(TestCase):
        train_callback.end(RunContext(run_context_customized))
        LINEAGE_DATA_MANAGER.start_load_data().join()
        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, self._search_condition)
        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition)
        assert res.get('object')[0].get('model_lineage', {}).get('loss_function') \
            == 'SoftmaxCrossEntropyWithLogits'
        assert res.get('object')[0].get('model_lineage', {}).get('network') == 'ResNet'

@@ -190,7 +190,7 @@ class TestModelApi(TestCase):
        search_condition = {
            'sorted_name': 'summary_dir'
        }
        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
        expect_objects = expect_result.get('object')
        for idx, res_object in enumerate(res.get('object')):
            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')

@@ -228,7 +228,7 @@ class TestModelApi(TestCase):
            ],
            'count': 2
        }
        partial_res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
        partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
        expect_objects = expect_result.get('object')
        for idx, res_object in enumerate(partial_res.get('object')):
            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')

@@ -266,7 +266,7 @@ class TestModelApi(TestCase):
            ],
            'count': 2
        }
        partial_res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
        partial_res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
        expect_objects = expect_result.get('object')
        for idx, res_object in enumerate(partial_res.get('object')):
            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')

@@ -295,7 +295,7 @@ class TestModelApi(TestCase):
            ],
            'count': 3
        }
        partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1)
        partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1)
        expect_objects = expect_result.get('object')
        for idx, res_object in enumerate(partial_res1.get('object')):
            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')

@@ -314,7 +314,7 @@ class TestModelApi(TestCase):
            'object': [],
            'count': 0
        }
        partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2)
        partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2)
        assert expect_result == partial_res2

    @pytest.mark.level0

@@ -335,7 +335,7 @@ class TestModelApi(TestCase):
                'eq': self._empty_train_id
            }
        }
        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
        assert expect_result == res

    @pytest.mark.level0

@@ -366,7 +366,7 @@ class TestModelApi(TestCase):
            ],
            'count': 1
        }
        res = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition)
        res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition)
        assert expect_result == res

    @pytest.mark.level0

@@ -386,6 +386,7 @@ class TestModelApi(TestCase):
            'The search_condition element summary_dir should be dict.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -398,6 +399,7 @@ class TestModelApi(TestCase):
            'The sorted_name must exist when sorted_type exists.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -408,6 +410,7 @@ class TestModelApi(TestCase):
            'Invalid search_condition type, it should be dict.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -420,6 +423,7 @@ class TestModelApi(TestCase):
            'The limit must be int.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -440,6 +444,7 @@ class TestModelApi(TestCase):
            'The offset must be int.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -454,6 +459,7 @@ class TestModelApi(TestCase):
            'The search attribute not supported.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -475,6 +481,7 @@ class TestModelApi(TestCase):
            'The sorted_type must be ascending or descending',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -489,6 +496,7 @@ class TestModelApi(TestCase):
            'The compare condition should be in',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -503,6 +511,7 @@ class TestModelApi(TestCase):
            'The parameter metric/accuracy is invalid.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -526,7 +535,7 @@ class TestModelApi(TestCase):
            'object': [],
            'count': 0
        }
        partial_res1 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition1)
        partial_res1 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition1)
        assert expect_result == partial_res1

        # the (offset + 1) * limit > count

@@ -542,7 +551,7 @@ class TestModelApi(TestCase):
            'object': [],
            'count': 2
        }
        partial_res2 = filter_summary_lineage(LINEAGE_DATA_MANAGER, search_condition2)
        partial_res2 = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=search_condition2)
        assert expect_result == partial_res2

    @pytest.mark.level0

@@ -566,6 +575,7 @@ class TestModelApi(TestCase):
            f'The parameter {condition_key} is invalid. Its operation should be `eq`, `in` or `not_in`.',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -589,6 +599,7 @@ class TestModelApi(TestCase):
            "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -610,6 +621,7 @@ class TestModelApi(TestCase):
            'The sorted_name must be in',
            filter_summary_lineage,
            LINEAGE_DATA_MANAGER,
            None,
            search_condition
        )

@@ -23,7 +23,7 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError

class TestValidateSearchModelCondition(TestCase):
    """Test the mothod of validate_search_model_condition."""
    """Test the method of validate_search_model_condition."""

    def test_validate_search_model_condition_param_type_error(self):
        """Test the method of validate_search_model_condition with LineageParamTypeError."""
        condition = {
@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test common module."""

@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test validator."""
@@ -0,0 +1,161 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test optimizer config schema."""
from copy import deepcopy

from mindinsight.optimizer.common.validator.optimizer_config import OptimizerConfig
from mindinsight.optimizer.common.enums import TargetGroup, HyperParamSource

_BASE_DATA = {
    'command': 'python ./test.py',
    'summary_base_dir': './demo_lineage_r0.3',
    'tuner': {
        'name': 'gp',
        'args': {
            'method': 'ucb'
        }
    },
    'target': {
        'group': 'metric',
        'name': 'Accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            'bounds': [0.0001, 0.001],
            'type': 'float'
        },
        'batch_size': {
            'choice': [32, 64, 128, 256],
            'type': 'int'
        },
        'decay_step': {
            'choice': [20],
            'type': 'int'
        }
    }
}

class TestOptimizerConfig:
    """Test the OptimizerConfig schema."""
    _config_dict = dict(_BASE_DATA)

    def test_config_dict_with_wrong_type(self):
        """Test config dict with wrong type."""
        config_dict = deepcopy(self._config_dict)
        init_list = ['a']
        init_str = 'a'

        config_dict['command'] = init_list
        config_dict['summary_base_dir'] = init_list
        config_dict['target']['name'] = init_list
        config_dict['target']['goal'] = init_list
        config_dict['parameters']['learning_rate']['bounds'] = init_str
        config_dict['parameters']['learning_rate']['choice'] = init_str

        expected_err = {
            'command': ["Value type should be 'str'."],
            'parameters': {
                'learning_rate': {
                    'bounds': ["Value type should be 'list'."],
                    'choice': ["Value type should be 'list'."]
                }
            },
            'summary_base_dir': ["Value type should be 'str'."],
            'target': {
                'name': ["Value type should be 'str'."],
                'goal': ["Value type should be 'str'."]
            }
        }
        err = OptimizerConfig().validate(config_dict)
        assert expected_err == err

    def test_config_dict_with_wrong_value(self):
        """Test config dict with wrong value."""
        config_dict = deepcopy(self._config_dict)
        init_list = ['a']
        init_str = 'a'

        config_dict['target']['group'] = init_str
        config_dict['target']['goal'] = init_str
        config_dict['tuner']['name'] = init_str
        config_dict['parameters']['learning_rate']['bounds'] = init_list
        config_dict['parameters']['learning_rate']['choice'] = init_list
        config_dict['parameters']['learning_rate']['type'] = init_str

        expected_err = {
            'parameters': {
                'learning_rate': {
                    'bounds': {
                        0: ['Value(s) should be integer or float.']
                    },
                    'choice': {
                        0: ['Value(s) should be integer or float.']
                    },
                    'type': ["The type should be in ['int', 'float']."]
                }
            },
            'target': {
                'goal': ["Value should be in ['maximize', 'minimize']. Current value is a."],
                'group': ["Value should be in ['system_defined', 'metric']. Current value is a."]},
            'tuner': {
                'name': ['Must be one of: gp.']
            }
        }
        err = OptimizerConfig().validate(config_dict)
        assert expected_err == err

    def test_target_combination(self):
        """Test target combination."""
        config_dict = deepcopy(self._config_dict)
        config_dict['target']['group'] = TargetGroup.SYSTEM_DEFINED.value
        config_dict['target']['name'] = 'a'

        expected_err = {
            'target': {
                'group': 'This target is not system defined. Current group is: system_defined.'
            }
        }
        err = OptimizerConfig().validate(config_dict)
        assert expected_err == err

    def test_parameters_combination1(self):
        """Test parameters combination."""
        config_dict = deepcopy(self._config_dict)
        config_dict['parameters']['decay_step']['source'] = HyperParamSource.SYSTEM_DEFINED.value

        expected_err = {
            'parameters': {
                'decay_step': {
                    'source': 'This param is not system defined. Current source is: system_defined.'
                }
            }
        }
        err = OptimizerConfig().validate(config_dict)
        assert expected_err == err

    def test_parameters_combination2(self):
        """Test parameters combination."""
        config_dict = deepcopy(self._config_dict)
        config_dict['parameters']['decay_step']['bounds'] = [1, 40]

        expected_err = {
            'parameters': {
                'decay_step': {
                    '_schema': ["Only one of ['bounds', 'choice'] should be specified."]
                }
            }
        }
        err = OptimizerConfig().validate(config_dict)
        assert expected_err == err
@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test utils."""