Merge pull request !133 from luopengting/lineage_added_infotags/v0.3.0-alpha
| @@ -18,7 +18,7 @@ import os | |||||
| from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob | from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob | ||||
| from mindinsight.lineagemgr.common.log import logger | from mindinsight.lineagemgr.common.log import logger | ||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageFileNotFoundError | from mindinsight.lineagemgr.common.exceptions.exceptions import LineageFileNotFoundError | ||||
| from mindinsight.lineagemgr.common.validator.validate import validate_train_id | |||||
| from mindinsight.lineagemgr.common.validator.validate import validate_train_id, validate_added_info | |||||
| from mindinsight.lineagemgr.lineage_parser import LineageParser, LINEAGE | from mindinsight.lineagemgr.lineage_parser import LineageParser, LINEAGE | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| @@ -26,6 +26,7 @@ from mindinsight.utils.exceptions import ParamValueError | |||||
| def update_lineage_object(data_manager, train_id, added_info: dict): | def update_lineage_object(data_manager, train_id, added_info: dict): | ||||
| """Update lineage objects about tag and remark.""" | """Update lineage objects about tag and remark.""" | ||||
| validate_train_id(train_id) | validate_train_id(train_id) | ||||
| validate_added_info(added_info) | |||||
| cache_item = data_manager.get_brief_train_job(train_id) | cache_item = data_manager.get_brief_train_job(train_id) | ||||
| lineage_item = cache_item.get(key=LINEAGE, raise_exception=False) | lineage_item = cache_item.get(key=LINEAGE, raise_exception=False) | ||||
| if lineage_item is None: | if lineage_item is None: | ||||
| @@ -362,8 +362,9 @@ def validate_condition(search_condition): | |||||
| log.error(err_msg) | log.error(err_msg) | ||||
| raise LineageParamValueError(err_msg) | raise LineageParamValueError(err_msg) | ||||
| if not (sorted_name in FIELD_MAPPING | if not (sorted_name in FIELD_MAPPING | ||||
| or (sorted_name.startswith('metric/') and len(sorted_name) > 7) | |||||
| or (sorted_name.startswith('user_defined/') and len(sorted_name) > 13)): | |||||
| or (sorted_name.startswith('metric/') and len(sorted_name) > len('metric/')) | |||||
| or (sorted_name.startswith('user_defined/') and len(sorted_name) > len('user_defined/')) | |||||
| or sorted_name in ['tag']): | |||||
| log.error(err_msg) | log.error(err_msg) | ||||
| raise LineageParamValueError(err_msg) | raise LineageParamValueError(err_msg) | ||||
| @@ -460,3 +461,54 @@ def validate_train_id(relative_path): | |||||
| raise ParamValueError( | raise ParamValueError( | ||||
| "Summary dir should be relative path starting with './'." | "Summary dir should be relative path starting with './'." | ||||
| ) | ) | ||||
def validate_range(name, value, min_value, max_value):
    """
    Check if value is in [min_value, max_value].

    Args:
        name (str): Value name, used only to build the error message.
        value (Union[int, float]): Value to be checked.
        min_value (Union[int, float]): Min value (inclusive).
        max_value (Union[int, float]): Max value (inclusive).

    Raises:
        LineageParamValueError, if value type is invalid or value is out of [min_value, max_value].
    """
    if not isinstance(value, (int, float)):
        raise LineageParamValueError("Value should be int or float.")

    # Written as a positive containment test so that NaN, for which every
    # ordering comparison is False, is rejected instead of slipping through.
    if not min_value <= value <= max_value:
        # `%s` reports the bounds as given; `%d` would truncate float bounds
        # (e.g. 0.5 would be printed as 0) and misreport the valid range.
        raise LineageParamValueError("The %s should in [%s, %s]." % (name, min_value, max_value))
def validate_added_info(added_info: dict):
    """
    Check if added_info is valid.

    Only the keys "tag" and "remark" are accepted. "tag" must be an int in [0, 10]
    and "remark" must be a str whose length is in [0, 128].

    Args:
        added_info (dict): The added info.

    Raises:
        LineageParamValueError, if added_info is not a dict, contains an unsupported key,
            or holds a value of the wrong type or out of range.
    """
    if not isinstance(added_info, dict):
        # Fail with a parameter error rather than an AttributeError on `.keys()`.
        err_msg = "The added info must be dict."
        log.error(err_msg)
        raise LineageParamValueError(err_msg)

    added_info_keys = ["tag", "remark"]
    if not set(added_info.keys()).issubset(added_info_keys):
        err_msg = "Keys must be in {}.".format(added_info_keys)
        log.error(err_msg)
        raise LineageParamValueError(err_msg)

    for key, value in added_info.items():
        if key == "tag":
            # bool is a subclass of int, so it has to be excluded explicitly;
            # otherwise True/False would be accepted as tag 1/0.
            if isinstance(value, bool) or not isinstance(value, int):
                raise LineageParamValueError("'tag' must be int.")
            # tag should be in [0, 10].
            validate_range("tag", value, min_value=0, max_value=10)
        elif key == "remark":
            if not isinstance(value, str):
                raise LineageParamValueError("'remark' must be str.")
            # length of remark should be in [0, 128].
            validate_range("length of remark", len(value), min_value=0, max_value=128)
| @@ -271,25 +271,6 @@ class Querier: | |||||
| return False | return False | ||||
| return True | return True | ||||
| def _cmp(obj1: SuperLineageObj, obj2: SuperLineageObj): | |||||
| value1 = obj1.lineage_obj.get_value_by_key(sorted_name) | |||||
| value2 = obj2.lineage_obj.get_value_by_key(sorted_name) | |||||
| if value1 is None and value2 is None: | |||||
| cmp_result = 0 | |||||
| elif value1 is None: | |||||
| cmp_result = -1 | |||||
| elif value2 is None: | |||||
| cmp_result = 1 | |||||
| else: | |||||
| try: | |||||
| cmp_result = (value1 > value2) - (value1 < value2) | |||||
| except TypeError: | |||||
| type1 = type(value1).__name__ | |||||
| type2 = type(value2).__name__ | |||||
| cmp_result = (type1 > type2) - (type1 < type2) | |||||
| return cmp_result | |||||
| if condition is None: | if condition is None: | ||||
| condition = {} | condition = {} | ||||
| @@ -298,19 +279,7 @@ class Querier: | |||||
| super_lineage_objs.sort(key=lambda x: x.update_time, reverse=True) | super_lineage_objs.sort(key=lambda x: x.update_time, reverse=True) | ||||
| results = list(filter(_filter, super_lineage_objs)) | results = list(filter(_filter, super_lineage_objs)) | ||||
| if ConditionParam.SORTED_NAME.value in condition: | |||||
| sorted_name = condition.get(ConditionParam.SORTED_NAME.value) | |||||
| if self._is_valid_field(sorted_name): | |||||
| raise LineageQuerierParamException( | |||||
| 'condition', | |||||
| 'The sorted name {} not supported.'.format(sorted_name) | |||||
| ) | |||||
| sorted_type = condition.get(ConditionParam.SORTED_TYPE.value) | |||||
| reverse = sorted_type == 'descending' | |||||
| results = sorted( | |||||
| results, key=functools.cmp_to_key(_cmp), reverse=reverse | |||||
| ) | |||||
| results = self._sorted_results(results, condition) | |||||
| offset_results = self._handle_limit_and_offset(condition, results) | offset_results = self._handle_limit_and_offset(condition, results) | ||||
| @@ -338,6 +307,55 @@ class Querier: | |||||
| return lineage_info | return lineage_info | ||||
def _sorted_results(self, results, condition):
    """
    Sort results by the field named in the condition.

    Args:
        results (list): The lineage objects to be sorted.
        condition (dict): The query condition; the sorted name and sorted type are read from it.

    Returns:
        list, `results` unchanged when no sorted name is given, else a new sorted list.

    Raises:
        LineageQuerierParamException, if the sorted name is not supported.
    """
    name_key = ConditionParam.SORTED_NAME.value
    if name_key not in condition:
        return results

    sorted_name = condition.get(name_key)
    descending = condition.get(ConditionParam.SORTED_TYPE.value) == 'descending'

    if sorted_name in ['tag']:
        # 'tag' lives in the user-added info, not in the lineage object itself.
        def _extract(obj):
            return obj.added_info.get(sorted_name)
    else:
        # Despite its name, `_is_valid_field` is truthy for unsupported names.
        if self._is_valid_field(sorted_name):
            raise LineageQuerierParamException(
                'condition',
                'The sorted name {} not supported.'.format(sorted_name)
            )

        def _extract(obj):
            return obj.lineage_obj.get_value_by_key(sorted_name)

    def _compare(left_obj, right_obj):
        left = _extract(left_obj)
        right = _extract(right_obj)
        # A missing value orders before any present value; two missing values tie.
        if left is None or right is None:
            return (left is not None) - (right is not None)
        try:
            return (left > right) - (left < right)
        except TypeError:
            # Values of incomparable types are ordered by their type names.
            left_type = type(left).__name__
            right_type = type(right).__name__
            return (left_type > right_type) - (left_type < right_type)

    return sorted(results, key=functools.cmp_to_key(_compare), reverse=descending)
| def _organize_customized(self, offset_results): | def _organize_customized(self, offset_results): | ||||
| """Organize customized.""" | """Organize customized.""" | ||||
| customized = dict() | customized = dict() | ||||
| @@ -403,8 +421,8 @@ class Querier: | |||||
| Returns: | Returns: | ||||
| bool, `True` if the field name is valid, else `False`. | bool, `True` if the field name is valid, else `False`. | ||||
| """ | """ | ||||
| return field_name not in FIELD_MAPPING and \ | |||||
| not field_name.startswith(('metric/', 'user_defined/')) | |||||
| return field_name not in FIELD_MAPPING \ | |||||
| and not field_name.startswith(('metric/', 'user_defined/')) | |||||
| def _handle_limit_and_offset(self, condition, result): | def _handle_limit_and_offset(self, condition, result): | ||||
| """ | """ | ||||