Browse Source

!133 add sorting and checking for added_info

Merge pull request !133 from luopengting/lineage_added_info
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
f379b2046f
3 changed files with 108 additions and 37 deletions
  1. +2
    -1
      mindinsight/lineagemgr/cache_item_updater.py
  2. +54
    -2
      mindinsight/lineagemgr/common/validator/validate.py
  3. +52
    -34
      mindinsight/lineagemgr/querier/querier.py

+ 2
- 1
mindinsight/lineagemgr/cache_item_updater.py View File

@@ -18,7 +18,7 @@ import os
from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
from mindinsight.lineagemgr.common.log import logger
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageFileNotFoundError
from mindinsight.lineagemgr.common.validator.validate import validate_train_id
from mindinsight.lineagemgr.common.validator.validate import validate_train_id, validate_added_info
from mindinsight.lineagemgr.lineage_parser import LineageParser, LINEAGE
from mindinsight.utils.exceptions import ParamValueError

@@ -26,6 +26,7 @@ from mindinsight.utils.exceptions import ParamValueError
def update_lineage_object(data_manager, train_id, added_info: dict):
"""Update lineage objects about tag and remark."""
validate_train_id(train_id)
validate_added_info(added_info)
cache_item = data_manager.get_brief_train_job(train_id)
lineage_item = cache_item.get(key=LINEAGE, raise_exception=False)
if lineage_item is None:


+ 54
- 2
mindinsight/lineagemgr/common/validator/validate.py View File

@@ -362,8 +362,9 @@ def validate_condition(search_condition):
log.error(err_msg)
raise LineageParamValueError(err_msg)
if not (sorted_name in FIELD_MAPPING
or (sorted_name.startswith('metric/') and len(sorted_name) > 7)
or (sorted_name.startswith('user_defined/') and len(sorted_name) > 13)):
or (sorted_name.startswith('metric/') and len(sorted_name) > len('metric/'))
or (sorted_name.startswith('user_defined/') and len(sorted_name) > len('user_defined/'))
or sorted_name in ['tag']):
log.error(err_msg)
raise LineageParamValueError(err_msg)

@@ -460,3 +461,54 @@ def validate_train_id(relative_path):
raise ParamValueError(
"Summary dir should be relative path starting with './'."
)


def validate_range(name, value, min_value, max_value):
"""
Check if value is in [min_value, max_value].

Args:
name (str): Value name.
value (Union[int, float]): Value to be check.
min_value (Union[int, float]): Min value.
max_value (Union[int, float]): Max value.

Raises:
LineageParamValueError, if value type is invalid or value is out of [min_value, max_value].

"""
if not isinstance(value, (int, float)):
raise LineageParamValueError("Value should be int or float.")

if value < min_value or value > max_value:
raise LineageParamValueError("The %s should in [%d, %d]." % (name, min_value, max_value))


def validate_added_info(added_info: dict):
"""
Check if added_info is valid.

Args:
added_info (dict): The added info.

Raises:
bool, if added_info is valid, return True.

"""
added_info_keys = ["tag", "remark"]
if not set(added_info.keys()).issubset(added_info_keys):
err_msg = "Keys must be in {}.".format(added_info_keys)
log.error(err_msg)
raise LineageParamValueError(err_msg)

for key, value in added_info.items():
if key == "tag":
if not isinstance(value, int):
raise LineageParamValueError("'tag' must be int.")
# tag should be in [0, 10].
validate_range("tag", value, min_value=0, max_value=10)
elif key == "remark":
if not isinstance(value, str):
raise LineageParamValueError("'remark' must be str.")
# length of remark should be in [0, 128].
validate_range("length of remark", len(value), min_value=0, max_value=128)

+ 52
- 34
mindinsight/lineagemgr/querier/querier.py View File

@@ -271,25 +271,6 @@ class Querier:
return False
return True

def _cmp(obj1: SuperLineageObj, obj2: SuperLineageObj):
value1 = obj1.lineage_obj.get_value_by_key(sorted_name)
value2 = obj2.lineage_obj.get_value_by_key(sorted_name)

if value1 is None and value2 is None:
cmp_result = 0
elif value1 is None:
cmp_result = -1
elif value2 is None:
cmp_result = 1
else:
try:
cmp_result = (value1 > value2) - (value1 < value2)
except TypeError:
type1 = type(value1).__name__
type2 = type(value2).__name__
cmp_result = (type1 > type2) - (type1 < type2)
return cmp_result

if condition is None:
condition = {}

@@ -298,19 +279,7 @@ class Querier:
super_lineage_objs.sort(key=lambda x: x.update_time, reverse=True)

results = list(filter(_filter, super_lineage_objs))

if ConditionParam.SORTED_NAME.value in condition:
sorted_name = condition.get(ConditionParam.SORTED_NAME.value)
if self._is_valid_field(sorted_name):
raise LineageQuerierParamException(
'condition',
'The sorted name {} not supported.'.format(sorted_name)
)
sorted_type = condition.get(ConditionParam.SORTED_TYPE.value)
reverse = sorted_type == 'descending'
results = sorted(
results, key=functools.cmp_to_key(_cmp), reverse=reverse
)
results = self._sorted_results(results, condition)

offset_results = self._handle_limit_and_offset(condition, results)

@@ -338,6 +307,55 @@ class Querier:

return lineage_info

def _sorted_results(self, results, condition):
"""Get sorted results."""
def _cmp(value1, value2):
if value1 is None and value2 is None:
cmp_result = 0
elif value1 is None:
cmp_result = -1
elif value2 is None:
cmp_result = 1
else:
try:
cmp_result = (value1 > value2) - (value1 < value2)
except TypeError:
type1 = type(value1).__name__
type2 = type(value2).__name__
cmp_result = (type1 > type2) - (type1 < type2)
return cmp_result

def _cmp_added_info(obj1: SuperLineageObj, obj2: SuperLineageObj):
value1 = obj1.added_info.get(sorted_name)
value2 = obj2.added_info.get(sorted_name)
return _cmp(value1, value2)

def _cmp_super_lineage_obj(obj1: SuperLineageObj, obj2: SuperLineageObj):
value1 = obj1.lineage_obj.get_value_by_key(sorted_name)
value2 = obj2.lineage_obj.get_value_by_key(sorted_name)

return _cmp(value1, value2)

if ConditionParam.SORTED_NAME.value in condition:
sorted_name = condition.get(ConditionParam.SORTED_NAME.value)
sorted_type = condition.get(ConditionParam.SORTED_TYPE.value)
reverse = sorted_type == 'descending'
if sorted_name in ['tag']:
results = sorted(
results, key=functools.cmp_to_key(_cmp_added_info), reverse=reverse
)
return results

if self._is_valid_field(sorted_name):
raise LineageQuerierParamException(
'condition',
'The sorted name {} not supported.'.format(sorted_name)
)
results = sorted(
results, key=functools.cmp_to_key(_cmp_super_lineage_obj), reverse=reverse
)
return results

def _organize_customized(self, offset_results):
"""Organize customized."""
customized = dict()
@@ -403,8 +421,8 @@ class Querier:
Returns:
bool, `True` if the field name is valid, else `False`.
"""
return field_name not in FIELD_MAPPING and \
not field_name.startswith(('metric/', 'user_defined/'))
return field_name not in FIELD_MAPPING \
and not field_name.startswith(('metric/', 'user_defined/'))

def _handle_limit_and_offset(self, condition, result):
"""


Loading…
Cancel
Save