Merge pull request !27 from kouzhenzhong/user_definedtags/v0.2.0-alpha
@@ -0,0 +1,129 @@
+// Copyright 2020 Huawei Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+package mindinsight;
+option cc_enable_arenas = true;
+
+// Event protocol buffer, top-level definition.
+message LineageEvent {
+  // Timestamp.
+  required double wall_time = 1;
+
+  // The step of train.
+  optional int64 step = 2;
+
+  oneof what {
+    // An event file was started, with the specified version.
+    // The current version is "Mindspore.Event:1".
+    string version = 3;
+
+    // Train lineage.
+    TrainLineage train_lineage = 4;
+
+    // Evaluation lineage.
+    EvaluationLineage evaluation_lineage = 5;
+
+    // Dataset graph.
+    DatasetGraph dataset_graph = 6;
+
+    // User defined info.
+    UserDefinedInfo user_defined_info = 7;
+  }
+}
+
+// User defined info.
+message UserDefinedInfo {
+  // Repeated user defined info entries.
+  repeated UserDefinedInfo user_info = 1;
+
+  // Key/value maps that hold both scalars and nested dicts.
+  map<string, UserDefinedInfo> map_dict = 2;
+  map<string, int32> map_int32 = 3;
+  map<string, string> map_str = 4;
+  map<string, double> map_double = 5;
+}
+
+// TrainLineage records the information of a training run.
+message TrainLineage {
+  message HyperParameters {
+    optional string optimizer = 1;
+    optional float learning_rate = 2;
+    optional string loss_function = 3;
+    optional int32 epoch = 4;
+    optional string parallel_mode = 5;
+    optional int32 device_num = 6;
+    optional int32 batch_size = 8;
+  }
+
+  message TrainDataset {
+    optional string train_dataset_path = 1;
+    optional int32 train_dataset_size = 2;
+  }
+
+  message Algorithm {
+    optional string network = 1;
+    optional float loss = 2;
+  }
+
+  message Model {
+    optional string path = 3;
+    optional int64 size = 4;
+  }
+
+  optional HyperParameters hyper_parameters = 1;
+  optional TrainDataset train_dataset = 2;
+  optional Algorithm algorithm = 3;
+  optional Model model = 4;
+}
+
+// EvaluationLineage records the information of an evaluation run.
+message EvaluationLineage {
+  message ValidDataset {
+    optional string valid_dataset_path = 1;
+    optional int32 valid_dataset_size = 2;
+  }
+
+  optional string metric = 2;
+  optional ValidDataset valid_dataset = 3;
+}
+
+// DatasetGraph.
+message DatasetGraph {
+  repeated DatasetGraph children = 1;
+  optional OperationParameter parameter = 2;
+  repeated Operation operations = 3;
+  optional Operation sampler = 4;
+}
+
+message Operation {
+  optional OperationParameter operationParam = 1;
+  repeated int32 size = 2;
+  repeated float weights = 3;
+}
+
+message OperationParameter {
+  map<string, string> mapStr = 1;
+  map<string, StrList> mapStrList = 2;
+  map<string, bool> mapBool = 3;
+  map<string, int32> mapInt = 4;
+  map<string, double> mapDouble = 5;
+}
+
+message StrList {
+  repeated string strValue = 1;
+}
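The new `UserDefinedInfo` message keeps one type-specific map per supported scalar, plus a nested map for dicts. A minimal sketch of how a flat user dict could land in the message, assuming the generated module `mindinsight_lineage_pb2` from this PR (keys and values are placeholders):

```python
# Hedged sketch: assumes the generated module mindinsight_lineage_pb2 built
# from the proto above; the keys and values are placeholders.
import time

from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent

event = LineageEvent()
event.wall_time = time.time()

# Each scalar goes into the map that matches its Python type.
info = event.user_defined_info.user_info.add()
info.map_str["version"] = "resnet50_v1"   # str   -> map_str
info.map_int32["epoch_tag"] = 10          # int   -> map_int32
info.map_double["base_lr"] = 0.01         # float -> map_double

serialized = event.SerializeToString()
```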
@@ -39,60 +39,9 @@ message Event {
    // Summary data
    Summary summary = 5;
-
-    // Train lineage
-    TrainLineage train_lineage = 6;
-
-    // Evaluation lineage
-    EvaluationLineage evaluation_lineage = 7;
-
-    // dataset graph
-    DatasetGraph dataset_graph = 9;
  }
}
-
-// TrainLineage records infos of a train.
-message TrainLineage{
-  message HyperParameters{
-    optional string optimizer = 1;
-    optional float learning_rate = 2;
-    optional string loss_function = 3;
-    optional int32 epoch = 4;
-    optional string parallel_mode = 5;
-    optional int32 device_num = 6;
-    optional int32 batch_size = 8;
-  }
-  message TrainDataset{
-    optional string train_dataset_path = 1;
-    optional int32 train_dataset_size = 2;
-  }
-  message Algorithm{
-    optional string network = 1;
-    optional float loss = 2;
-  }
-  message Model{
-    optional string path = 3;
-    optional int64 size = 4;
-  }
-  optional HyperParameters hyper_parameters = 1;
-  optional TrainDataset train_dataset = 2;
-  optional Algorithm algorithm = 3;
-  optional Model model = 4;
-}
-
-//EvalLineage records infos of evaluation.
-message EvaluationLineage{
-  message ValidDataset{
-    optional string valid_dataset_path = 1;
-    optional int32 valid_dataset_size = 2;
-  }
-  optional string metric = 2;
-  optional ValidDataset valid_dataset = 3;
-}

// A Summary is a set of named values that be produced regularly during training
message Summary {
@@ -127,29 +76,3 @@ message Summary {
  // Set of values for the summary.
  repeated Value value = 1;
}
-
-// DatasetGraph
-message DatasetGraph {
-  repeated DatasetGraph children = 1;
-  optional OperationParameter parameter = 2;
-  repeated Operation operations = 3;
-  optional Operation sampler = 4;
-}
-
-message Operation {
-  optional OperationParameter operationParam = 1;
-  repeated int32 size = 2;
-  repeated float weights = 3;
-}
-
-message OperationParameter{
-  map<string, string> mapStr = 1;
-  map<string, StrList> mapStrList = 2;
-  map<string, bool> mapBool = 3;
-  map<string, int32> mapInt = 4;
-  map<string, double> mapDouble = 5;
-}
-
-message StrList {
-  repeated string strValue = 1;
-}
@@ -112,10 +112,10 @@ def filter_summary_lineage(summary_base_dir, search_condition=None):
            directories generated by training.
        search_condition (dict): The search condition. When filtering and
            sorting, in addition to the following supported fields, fields
-            prefixed with `metric_` are also supported. The fields prefixed with
-            `metric_` are related to the `metrics` parameter in the training
+            prefixed with `metric/` are also supported. The fields prefixed with
+            `metric/` are related to the `metrics` parameter in the training
            script. For example, if the key of `metrics` parameter is
-            `accuracy`, the field should be `metric_accuracy`. Default: None.
+            `accuracy`, the field should be `metric/accuracy`. Default: None.

            - summary_dir (dict): The filter condition of summary directory.
@@ -187,7 +187,7 @@ def filter_summary_lineage(summary_base_dir, search_condition=None):
        >>>         'ge': 128,
        >>>         'le': 256
        >>>     },
-        >>>     'metric_accuracy': {
+        >>>     'metric/accuracy': {
        >>>         'lt': 0.1
        >>>     },
        >>>     'sorted_name': 'summary_dir',
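With the separator change, metric-related filter and sort keys use `metric/<name>` instead of `metric_<name>`. A hedged example call, assuming `filter_summary_lineage` is importable from the public `mindinsight.lineagemgr` package (the base directory is a placeholder):

```python
# Hedged example of a search condition using the new `metric/` prefix.
# '/path/to/summary_base_dir' is a placeholder path.
from mindinsight.lineagemgr import filter_summary_lineage

search_condition = {
    'metric/accuracy': {'gt': 0.5, 'lt': 3.0},
    'sorted_name': 'metric/accuracy',
    'sorted_type': 'descending',
    'limit': 3,
    'offset': 0,
}
result = filter_summary_lineage('/path/to/summary_base_dir', search_condition)
print(result['count'], result['customized'])
```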
@@ -23,7 +23,8 @@ from mindinsight.utils.exceptions import \
    MindInsightException
from mindinsight.lineagemgr.common.validator.validate import validate_train_run_context, \
    validate_eval_run_context, validate_file_path, validate_network, \
-    validate_int_params, validate_summary_record, validate_raise_exception
+    validate_int_params, validate_summary_record, validate_raise_exception, \
+    validate_user_defined_info
from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \
    LineageGetModelFileError, LineageLogError
@@ -71,7 +72,7 @@ class TrainLineage(Callback):
        >>> lineagemgr = TrainLineage(summary_record=summary_writer)
        >>> model.train(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
    """
-    def __init__(self, summary_record, raise_exception=False):
+    def __init__(self, summary_record, raise_exception=False, user_defined_info=None):
        super(TrainLineage, self).__init__()
        try:
            validate_raise_exception(raise_exception)
@@ -85,6 +86,11 @@ class TrainLineage(Callback):
            self.lineage_log_path = summary_log_path + '_lineage'

            self.initial_learning_rate = None
+
+            self.user_defined_info = user_defined_info
+            if user_defined_info:
+                validate_user_defined_info(user_defined_info)
        except MindInsightException as err:
            log.error(err)
            if raise_exception:
@@ -104,6 +110,10 @@ class TrainLineage(Callback):
        """
        log.info('Initialize training lineage collection...')
+        if self.user_defined_info:
+            lineage_summary = LineageSummary(summary_log_path=self.lineage_log_path)
+            lineage_summary.record_user_defined_info(self.user_defined_info)
+
        if not isinstance(run_context, RunContext):
            error_msg = f'Invalid TrainLineage run_context.'
            log.error(error_msg)
@@ -239,7 +249,7 @@ class EvalLineage(Callback):
        >>> lineagemgr = EvalLineage(summary_record=summary_writer)
        >>> model.eval(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
    """
-    def __init__(self, summary_record, raise_exception=False):
+    def __init__(self, summary_record, raise_exception=False, user_defined_info=None):
        super(EvalLineage, self).__init__()
        try:
            validate_raise_exception(raise_exception)
@@ -251,6 +261,11 @@ class EvalLineage(Callback):
            summary_log_path = summary_record.full_file_name
            validate_file_path(summary_log_path)
            self.lineage_log_path = summary_log_path + '_lineage'
+
+            self.user_defined_info = user_defined_info
+            if user_defined_info:
+                validate_user_defined_info(user_defined_info)
        except MindInsightException as err:
            log.error(err)
            if raise_exception:
@@ -269,6 +284,10 @@ class EvalLineage(Callback):
            MindInsightException: If validating parameter fails.
            LineageLogError: If recording lineage information fails.
        """
+        if self.user_defined_info:
+            lineage_summary = LineageSummary(summary_log_path=self.lineage_log_path)
+            lineage_summary.record_user_defined_info(self.user_defined_info)
+
        if not isinstance(run_context, RunContext):
            error_msg = f'Invalid EvalLineage run_context.'
            log.error(error_msg)
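Taken together, both callbacks now accept an optional `user_defined_info` dict that is validated in `__init__` and written to the lineage log when training or evaluation begins. A hedged usage sketch in the spirit of the docstring examples above; `summary_writer`, `model`, `dataset` and the other callbacks are assumed to be set up already, and the import path is assumed to be the public `mindinsight.lineagemgr` package:

```python
# Hedged sketch; values are placeholders and the surrounding MindSpore objects
# (summary_writer, model, dataset, model_ckpt, summary_callback) are assumed
# to exist, as in the docstring examples above.
from mindinsight.lineagemgr import TrainLineage, EvalLineage

user_info = {'version': 'resnet50_v1', 'epoch_tag': 10, 'base_lr': 0.01}

train_lineage = TrainLineage(summary_record=summary_writer,
                             raise_exception=False,
                             user_defined_info=user_info)
model.train(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, train_lineage])

# EvalLineage takes the same optional parameter.
eval_lineage = EvalLineage(summary_record=summary_writer, user_defined_info=user_info)
```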
@@ -226,7 +226,7 @@ class SearchModelConditionParameter(Schema):
        if not isinstance(attr, str):
            raise LineageParamValueError('The search attribute not supported.')
-        if attr not in FIELD_MAPPING and not attr.startswith('metric_'):
+        if attr not in FIELD_MAPPING and not attr.startswith(('metric/', 'user_defined/')):
            raise LineageParamValueError('The search attribute not supported.')

        if not isinstance(condition, dict):
@@ -238,7 +238,7 @@ class SearchModelConditionParameter(Schema):
            raise LineageParamValueError("The compare condition should be in "
                                         "('eq', 'lt', 'gt', 'le', 'ge', 'in').")
-        if attr.startswith('metric_'):
+        if attr.startswith('metric/'):
            if len(attr) == 7:
                raise LineageParamValueError(
                    'The search attribute not supported.'
@@ -349,12 +349,14 @@ def validate_condition(search_condition):
    if "sorted_name" in search_condition:
        sorted_name = search_condition.get("sorted_name")
        err_msg = "The sorted_name must be in {} or start with " \
-                  "`metric_`.".format(list(FIELD_MAPPING.keys()))
+                  "`metric/` or `user_defined/`.".format(list(FIELD_MAPPING.keys()))
        if not isinstance(sorted_name, str):
            log.error(err_msg)
            raise LineageParamValueError(err_msg)
-        if sorted_name not in FIELD_MAPPING and not (
-                sorted_name.startswith('metric_') and len(sorted_name) > 7):
+        if not (sorted_name in FIELD_MAPPING
+                or (sorted_name.startswith('metric/') and len(sorted_name) > 7)
+                or (sorted_name.startswith('user_defined/') and len(sorted_name) > 13)
+                ):
            log.error(err_msg)
            raise LineageParamValueError(err_msg)
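The length guards are just the lengths of the bare prefixes (`len('metric/') == 7`, `len('user_defined/') == 13`), so a prefix with nothing after it is rejected. A standalone illustration of the accepted values (`FIELD_MAPPING` is stubbed here; this is not the MindInsight implementation itself):

```python
# Standalone illustration of which sorted_name values pass the new check.
FIELD_MAPPING = {'summary_dir': None, 'loss_function': None}  # stub for illustration

def is_acceptable_sorted_name(name):
    return (name in FIELD_MAPPING
            or (name.startswith('metric/') and len(name) > len('metric/'))
            or (name.startswith('user_defined/') and len(name) > len('user_defined/')))

assert is_acceptable_sorted_name('summary_dir')
assert is_acceptable_sorted_name('metric/accuracy')
assert is_acceptable_sorted_name('user_defined/version')
assert not is_acceptable_sorted_name('metric/')           # bare prefix
assert not is_acceptable_sorted_name('metric_accuracy')   # old separator
```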
@@ -393,3 +395,38 @@ def validate_path(summary_path):
        raise LineageDirNotExistError("The summary path does not exist or is not a dir.")
    return summary_path
+
+
+def validate_user_defined_info(user_defined_info):
+    """
+    Validate user defined info.
+
+    Args:
+        user_defined_info (dict): The user defined info.
+
+    Raises:
+        LineageParamTypeError: If the type of a parameter is invalid.
+        LineageParamValueError: If a user defined key has already been defined in lineage.
+    """
+    if not isinstance(user_defined_info, dict):
+        log.error("Invalid user defined info. It should be a dict.")
+        raise LineageParamTypeError("Invalid user defined info. It should be a dict.")
+    for key, value in user_defined_info.items():
+        if not isinstance(key, str):
+            error_msg = "Dict key type {} is not supported in user defined info. " \
+                        "Only str is permitted now.".format(type(key))
+            log.error(error_msg)
+            raise LineageParamTypeError(error_msg)
+        if not isinstance(value, (int, str, float)):
+            error_msg = "Dict value type {} is not supported in user defined info. " \
+                        "Only str, int and float are permitted now.".format(type(value))
+            log.error(error_msg)
+            raise LineageParamTypeError(error_msg)
+
+    field_map = set(FIELD_MAPPING.keys())
+    user_defined_keys = set(user_defined_info.keys())
+    all_keys = field_map | user_defined_keys
+
+    if len(field_map) + len(user_defined_keys) != len(all_keys):
+        raise LineageParamValueError("Some user defined keys have already been defined in lineage.")
@@ -236,7 +236,7 @@ class Querier:
        See `ConditionType` and `ExpressionType` class for the rule of filtering
        and sorting. The filtering and sorting fields are defined in
-        `FIELD_MAPPING` or prefixed with `metric_`.
+        `FIELD_MAPPING` or prefixed with `metric/` or `user_defined/`.
        If the condition is `None`, all model lineage information will be
        returned.
@@ -288,7 +288,7 @@ class Querier:
        if condition is None:
            condition = {}

-        result = list(filter(_filter, self._lineage_objects))
+        results = list(filter(_filter, self._lineage_objects))

        if ConditionParam.SORTED_NAME.value in condition:
            sorted_name = condition.get(ConditionParam.SORTED_NAME.value)
@@ -299,19 +299,33 @@ class Querier:
                )
            sorted_type = condition.get(ConditionParam.SORTED_TYPE.value)
            reverse = sorted_type == 'descending'
-            result = sorted(
-                result, key=functools.cmp_to_key(_cmp), reverse=reverse
+            results = sorted(
+                results, key=functools.cmp_to_key(_cmp), reverse=reverse
            )

-        offset_result = self._handle_limit_and_offset(condition, result)
+        offset_results = self._handle_limit_and_offset(condition, results)
+
+        customized = dict()
+        for offset_result in offset_results:
+            for obj_name in ["metric", "user_defined"]:
+                obj = getattr(offset_result, obj_name)
+                if obj and isinstance(obj, dict):
+                    for key, value in obj.items():
+                        label = obj_name + "/" + key
+                        customized[label] = dict()
+                        customized[label]["label"] = label
+                        # User defined info is displayed by default.
+                        customized[label]["required"] = True
+                        customized[label]["type"] = type(value).__name__

        search_type = condition.get(ConditionParam.LINEAGE_TYPE.value)
        lineage_info = {
+            'customized': customized,
            'object': [
                item.to_dataset_lineage_dict() if search_type == LineageType.DATASET.value
-                else item.to_filtration_dict() for item in offset_result
+                else item.to_filtration_dict() for item in offset_results
            ],
-            'count': len(result)
+            'count': len(results)
        }
        return lineage_info
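For a run whose metric dict is `{'accuracy': 0.93}` and whose user defined info is `{'version': 'v1'}` (illustrative values), the `customized` block built above would look like:

```python
# Illustrative shape of the `customized` block (values are made up).
customized = {
    'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'},
    'user_defined/version': {'label': 'user_defined/version', 'required': True, 'type': 'str'},
}
```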
@@ -326,7 +340,8 @@ class Querier:
        Returns:
            bool, `True` if the field name is valid, else `False`.
        """
-        return field_name not in FIELD_MAPPING and not field_name.startswith('metric_')
+        return field_name not in FIELD_MAPPING and \
+               not field_name.startswith(('metric/', 'user_defined/'))

    def _handle_limit_and_offset(self, condition, result):
        """
@@ -397,11 +412,13 @@ class Querier:
        log_dir = os.path.dirname(log_path)
        try:
            lineage_info = LineageSummaryAnalyzer.get_summary_infos(log_path)
+            user_defined_info = LineageSummaryAnalyzer.get_user_defined_info(log_path)
            lineage_obj = LineageObj(
                log_dir,
                train_lineage=lineage_info.train_lineage,
                evaluation_lineage=lineage_info.eval_lineage,
-                dataset_graph=lineage_info.dataset_graph
+                dataset_graph=lineage_info.dataset_graph,
+                user_defined_info=user_defined_info
            )
            self._lineage_objects.append(lineage_obj)
            self._add_dataset_mark()
@@ -57,6 +57,8 @@ class LineageObj:
        - dataset_graph (Event): Dataset graph object.
+        - user_defined_info (list): User defined info list.

    Raises:
        LineageEventNotExistException: If train and evaluation event not exist.
        LineageEventFieldNotExistException: If the special event field not exist.
@@ -72,16 +74,19 @@ class LineageObj:
    _name_valid_dataset = 'valid_dataset'
    _name_dataset_graph = 'dataset_graph'
    _name_dataset_mark = 'dataset_mark'
+    _name_user_defined = 'user_defined'

    def __init__(self, summary_dir, **kwargs):
        self._lineage_info = {
            self._name_summary_dir: summary_dir
        }
+        user_defined_info_list = kwargs.get('user_defined_info', [])
        train_lineage = kwargs.get('train_lineage')
        evaluation_lineage = kwargs.get('evaluation_lineage')
        dataset_graph = kwargs.get('dataset_graph')
        if not any([train_lineage, evaluation_lineage, dataset_graph]):
            raise LineageEventNotExistException()
+        self._parse_user_defined_info(user_defined_info_list)
        self._parse_train_lineage(train_lineage)
        self._parse_evaluation_lineage(evaluation_lineage)
        self._parse_dataset_graph(dataset_graph)
@@ -107,6 +112,16 @@ class LineageObj:
        """
        return self._lineage_info.get(self._name_metric)

+    @property
+    def user_defined(self):
+        """
+        Get user defined information.
+
+        Returns:
+            dict, the user defined information.
+        """
+        return self._lineage_info.get(self._name_user_defined)
+
    @property
    def hyper_parameters(self):
        """
@@ -237,19 +252,22 @@ class LineageObj:
    def get_value_by_key(self, key):
        """
-        Get the value based on the key in `FIELD_MAPPING` or the key prefixed with `metric_`.
+        Get the value based on the key in `FIELD_MAPPING` or
+        the key prefixed with `metric/` or `user_defined/`.

        Args:
-            key (str): The key in `FIELD_MAPPING` or prefixed with `metric_`.
+            key (str): The key in `FIELD_MAPPING`
+                or prefixed with `metric/` or `user_defined/`.

        Returns:
            object, the value.
        """
-        if key.startswith('metric_'):
-            metric_key = key.split('_', 1)[1]
-            metric = self._filtration_result.get(self._name_metric)
-            if metric:
-                return metric.get(metric_key)
+        if key.startswith(('metric/', 'user_defined/')):
+            key_name, sub_key = key.split('/', 1)
+            sub_value_name = self._name_metric if key_name == 'metric' else self._name_user_defined
+            sub_value = self._filtration_result.get(sub_value_name)
+            if sub_value:
+                return sub_value.get(sub_key)
        return self._filtration_result.get(key)

    def _organize_filtration_result(self):
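The lookup now splits the key on the first `/` and dispatches to either the metric dict or the user defined dict. Illustrative calls, assuming a `LineageObj` instance whose filtration result contains `metric={'accuracy': 0.93}` and `user_defined={'version': 'v1'}` (made-up values):

```python
# Illustrative behaviour only; `lineage_obj` is an assumed LineageObj instance.
lineage_obj.get_value_by_key('metric/accuracy')       # -> 0.93
lineage_obj.get_value_by_key('user_defined/version')  # -> 'v1'
lineage_obj.get_value_by_key('summary_dir')           # plain FIELD_MAPPING key, unchanged path
```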
@@ -267,6 +285,8 @@ class LineageObj:
                if field.sub_name else base_attr
        # add metric into filtration result
        result[self._name_metric] = self.metric
+        result[self._name_user_defined] = self.user_defined
        # add dataset_graph into filtration result
        result[self._name_dataset_graph] = getattr(self, self._name_dataset_graph)
        return result
@@ -342,3 +362,15 @@ class LineageObj:
        if event_dict is None:
            raise LineageEventFieldNotExistException(self._name_evaluation_lineage)
        self._lineage_info[self._name_dataset_graph] = event_dict if event_dict else {}
+
+    def _parse_user_defined_info(self, user_defined_info_list):
+        """
+        Parse user defined info.
+
+        Args:
+            user_defined_info_list (list): User defined info list.
+        """
+        user_defined_infos = dict()
+        for user_defined_info in user_defined_info_list:
+            user_defined_infos.update(user_defined_info)
+
+        self._lineage_info[self._name_user_defined] = user_defined_infos
@@ -15,8 +15,9 @@
"""The converter between proto format event of lineage and dict."""
import time

-from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Event
-from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError
+from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent, UserDefinedInfo
+from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, \
+    LineageParamValueError
from mindinsight.lineagemgr.common.log import logger as log
@@ -28,9 +29,9 @@ def package_dataset_graph(graph):
        graph (dict): Dataset graph.

    Returns:
-        Event, the proto message event contains dataset graph.
+        LineageEvent, the proto event that contains the dataset graph.
    """
-    dataset_graph_event = Event()
+    dataset_graph_event = LineageEvent()
    dataset_graph_event.wall_time = time.time()

    dataset_graph = dataset_graph_event.dataset_graph
@@ -291,3 +292,57 @@ def _organize_parameter(parameter):
        parameter_result.update(result_str_list_para)
    return parameter_result
+
+
+def package_user_defined_info(user_dict):
+    """
+    Package user defined info.
+
+    Args:
+        user_dict (dict): User defined info dict.
+
+    Returns:
+        LineageEvent, the proto event that contains the user defined info.
+    """
+    user_event = LineageEvent()
+    user_event.wall_time = time.time()
+    user_defined_info = user_event.user_defined_info
+    _package_user_defined_info(user_dict, user_defined_info)
+
+    return user_event
+
+
+def _package_user_defined_info(user_defined_dict, user_defined_message):
+    """
+    Set the attributes of the user defined proto message.
+
+    Args:
+        user_defined_dict (dict): User defined info dict.
+        user_defined_message (UserDefinedInfo): Proto message of user defined info.
+
+    Raises:
+        LineageParamValueError: When the value is out of range.
+        LineageParamTypeError: When given a type that is not supported yet.
+    """
+    for key, value in user_defined_dict.items():
+        if not isinstance(key, str):
+            raise LineageParamTypeError("The key must be str.")
+
+        if isinstance(value, int):
+            attr_name = "map_int32"
+        elif isinstance(value, float):
+            attr_name = "map_double"
+        elif isinstance(value, str):
+            attr_name = "map_str"
+        else:
+            error_msg = "Value type {} is not supported in user defined event package. " \
+                        "Only str, int and float are permitted now.".format(type(value))
+            log.error(error_msg)
+            raise LineageParamTypeError(error_msg)
+
+        add_user_defined_info = user_defined_message.user_info.add()
+        try:
+            getattr(add_user_defined_info, attr_name)[key] = value
+        except ValueError:
+            raise LineageParamValueError("The value is out of range or not supported yet.")
@@ -17,7 +17,9 @@ import struct
from collections import namedtuple
from enum import Enum

-from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Event
+from google.protobuf.json_format import MessageToDict
+
+from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.datavisual.utils import crc32
from mindinsight.lineagemgr.common.exceptions.exceptions import MindInsightException, \
    LineageVerificationException, LineageSummaryAnalyzeException
@@ -87,11 +89,11 @@ class SummaryAnalyzer:
        Read event.

        Returns:
-            Event, the event body.
+            LineageEvent, the event body.
        """
        body_size = self._read_header()
        body_str = self._read_body(body_size)
-        event = Event().FromString(body_str)
+        event = LineageEvent().FromString(body_str)
        return event

    def _read_header(self):
@@ -206,3 +208,51 @@ class LineageSummaryAnalyzer(SummaryAnalyzer):
            raise LineageSummaryAnalyzeException()

        return lineage_info
+
+    @staticmethod
+    def get_user_defined_info(file_path):
+        """
+        Get user defined info.
+
+        Args:
+            file_path (str): The file path of the summary log.
+
+        Returns:
+            list, the user defined information as a list of dicts,
+                converted from the proto messages.
+        """
+        all_user_message = []
+        summary_analyzer = SummaryAnalyzer(file_path)
+        for event in summary_analyzer.load_events():
+            if event.HasField("user_defined_info"):
+                user_defined_info = MessageToDict(
+                    event,
+                    preserving_proto_field_name=True
+                ).get("user_defined_info")
+                user_dict = LineageSummaryAnalyzer._get_dict_from_proto(user_defined_info)
+                all_user_message.append(user_dict)
+
+        return all_user_message
+
+    @staticmethod
+    def _get_dict_from_proto(user_defined_info):
+        """
+        Convert the dict form of the UserDefinedInfo proto message into a flat dict.
+
+        Args:
+            user_defined_info (dict): The dict converted from the UserDefinedInfo proto message.
+
+        Returns:
+            dict, the converted dict.
+        """
+        user_dict = dict()
+        proto_dict = user_defined_info.get("user_info")
+        for proto_item in proto_dict:
+            if proto_item and isinstance(proto_item, dict):
+                key, value = list(list(proto_item.values())[0].items())[0]
+                if isinstance(value, dict):
+                    user_dict[key] = LineageSummaryAnalyzer._get_dict_from_proto(value)
+                else:
+                    user_dict[key] = value
+
+        return user_dict
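`MessageToDict(..., preserving_proto_field_name=True)` turns each `user_info` entry into a single-key dict such as `{'map_str': {'version': 'v1'}}`; `_get_dict_from_proto` then flattens those entries back into one plain dict. A standalone illustration of that flattening step (stub data shaped like the MessageToDict output, values made up):

```python
# Standalone illustration of the flattening done by _get_dict_from_proto.
user_defined_info = {
    'user_info': [
        {'map_str': {'version': 'v1'}},
        {'map_int32': {'epoch_tag': 10}},
        {'map_double': {'base_lr': 0.01}},
    ]
}

user_dict = {}
for proto_item in user_defined_info['user_info']:
    # Each entry holds exactly one populated map; grab its single key/value pair.
    key, value = list(list(proto_item.values())[0].items())[0]
    user_dict[key] = value

print(user_dict)  # {'version': 'v1', 'epoch_tag': 10, 'base_lr': 0.01}
```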
@@ -15,9 +15,9 @@
"""Record message to summary log."""
import time

-from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Event
+from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.summary.event_writer import EventWriter
-from ._summary_adapter import package_dataset_graph
+from ._summary_adapter import package_dataset_graph, package_user_defined_info


class LineageSummary:
@@ -50,9 +50,9 @@ class LineageSummary:
        run_context_args (dict): The train lineage info to log.

        Returns:
-            Event, the proto message event contains train lineage.
+            LineageEvent, the proto event that contains the train lineage.
        """
-        train_lineage_event = Event()
+        train_lineage_event = LineageEvent()
        train_lineage_event.wall_time = time.time()

        # Init train_lineage message.
@@ -124,9 +124,9 @@ class LineageSummary:
        run_context_args (dict): The evaluation lineage info to log.

        Returns:
-            Event, the proto message event contains evaluation lineage.
+            LineageEvent, the proto event that contains the evaluation lineage.
        """
-        train_lineage_event = Event()
+        train_lineage_event = LineageEvent()
        train_lineage_event.wall_time = time.time()

        # Init evaluation_lineage message.
@@ -165,3 +165,18 @@ class LineageSummary:
        self.event_writer.write_event_to_file(
            package_dataset_graph(dataset_graph).SerializeToString()
        )
+
+    def record_user_defined_info(self, user_dict):
+        """
+        Write user defined info to the summary log.
+
+        Note:
+            The input must be a dict whose values are str, int or float.
+            Nested dicts are not supported yet.
+
+        Args:
+            user_dict (dict): The user defined info to be recorded.
+        """
+        self.event_writer.write_event_to_file(
+            package_user_defined_info(user_dict).SerializeToString()
+        )
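A hedged usage sketch, mirroring how the `TrainLineage`/`EvalLineage` callbacks call this method; the log path is a placeholder, and the import of `LineageSummary` is omitted because its module path is not shown in this diff:

```python
# Hedged sketch; '/path/to/summary.log_lineage' is a placeholder path and
# LineageSummary is the class defined in this file.
lineage_summary = LineageSummary(summary_log_path='/path/to/summary.log_lineage')
lineage_summary.record_user_defined_info({'version': 'v1', 'epoch_tag': 10, 'base_lr': 0.01})
```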
@@ -32,7 +32,8 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF
                                                                  LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
+from .....ut.lineagemgr.querier import event_data
+from os import environ

LINEAGE_INFO_RUN1 = {
    'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
    'metric': {
@@ -68,6 +69,7 @@ LINEAGE_FILTRATION_EXCEPT_RUN = {
    'loss_function': 'SoftmaxCrossEntropyWithLogits',
    'train_dataset_path': None,
    'train_dataset_count': 1024,
+    'user_defined': {},
    'test_dataset_path': None,
    'test_dataset_count': None,
    'network': 'ResNet',
@@ -87,6 +89,7 @@ LINEAGE_FILTRATION_RUN1 = {
    'train_dataset_path': None,
    'train_dataset_count': 731,
    'test_dataset_path': None,
+    'user_defined': {},
    'test_dataset_count': 10240,
    'network': 'ResNet',
    'optimizer': 'Momentum',
@@ -106,6 +109,7 @@ LINEAGE_FILTRATION_RUN2 = {
    'loss_function': None,
    'train_dataset_path': None,
    'train_dataset_count': None,
+    'user_defined': {},
    'test_dataset_path': None,
    'test_dataset_count': 10240,
    'network': None,
@@ -318,6 +322,7 @@ class TestModelApi(TestCase):
    def test_filter_summary_lineage(self):
        """Test the interface of filter_summary_lineage."""
        expect_result = {
+            'customized': event_data.CUSTOMIZED__0,
            'object': [
                LINEAGE_FILTRATION_EXCEPT_RUN,
                LINEAGE_FILTRATION_RUN1,
@@ -360,16 +365,17 @@ class TestModelApi(TestCase):
                    SUMMARY_DIR_2
                ]
            },
-            'metric_accuracy': {
+            'metric/accuracy': {
                'lt': 3.0,
                'gt': 0.5
            },
-            'sorted_name': 'metric_accuracy',
+            'sorted_name': 'metric/accuracy',
            'sorted_type': 'descending',
            'limit': 3,
            'offset': 0
        }
        expect_result = {
+            'customized': event_data.CUSTOMIZED__0,
            'object': [
                LINEAGE_FILTRATION_RUN2,
                LINEAGE_FILTRATION_RUN1
@@ -397,16 +403,17 @@ class TestModelApi(TestCase):
                    './run2'
                ]
            },
-            'metric_accuracy': {
+            'metric/accuracy': {
                'lt': 3.0,
                'gt': 0.5
            },
-            'sorted_name': 'metric_accuracy',
+            'sorted_name': 'metric/accuracy',
            'sorted_type': 'descending',
            'limit': 3,
            'offset': 0
        }
        expect_result = {
+            'customized': event_data.CUSTOMIZED__0,
            'object': [
                LINEAGE_FILTRATION_RUN2,
                LINEAGE_FILTRATION_RUN1
@@ -431,10 +438,11 @@ class TestModelApi(TestCase):
            'batch_size': {
                'ge': 30
            },
-            'sorted_name': 'metric_accuracy',
+            'sorted_name': 'metric/accuracy',
            'lineage_type': None
        }
        expect_result = {
+            'customized': event_data.CUSTOMIZED__0,
            'object': [
                LINEAGE_FILTRATION_EXCEPT_RUN,
                LINEAGE_FILTRATION_RUN1
@@ -454,6 +462,7 @@ class TestModelApi(TestCase):
            'lineage_type': 'model'
        }
        expect_result = {
+            'customized': {},
            'object': [],
            'count': 0
        }
@@ -479,6 +488,7 @@ class TestModelApi(TestCase):
            'lineage_type': 'dataset'
        }
        expect_result = {
+            'customized': {},
            'object': [
                {
                    'summary_dir': summary_dir,
@@ -659,13 +669,13 @@ class TestModelApi(TestCase):
        # the search condition type error
        search_condition = {
-            'metric_accuracy': {
+            'metric/accuracy': {
                'lt': 'xxx'
            }
        }
        self.assertRaisesRegex(
            LineageSearchConditionParamError,
-            'The parameter metric_accuracy is invalid.',
+            'The parameter metric/accuracy is invalid.',
            filter_summary_lineage,
            BASE_SUMMARY_DIR,
            search_condition
@@ -741,12 +751,13 @@ class TestModelApi(TestCase):
        """Test the abnormal execution of the filter_summary_lineage interface."""
        # gt > lt
        search_condition1 = {
-            'metric_accuracy': {
+            'metric/accuracy': {
                'gt': 1,
                'lt': 0.5
            }
        }
        expect_result = {
+            'customized': {},
            'object': [],
            'count': 0
        }
@@ -762,6 +773,7 @@ class TestModelApi(TestCase):
            'offset': 4
        }
        expect_result = {
+            'customized': {},
            'object': [],
            'count': 1
        }
@@ -29,6 +29,7 @@ LINEAGE_FILTRATION_BASE = {
    'loss_function': 'SoftmaxCrossEntropyWithLogits',
    'train_dataset_path': None,
    'train_dataset_count': 64,
+    'user_defined': {},
    'test_dataset_path': None,
    'test_dataset_count': None,
    'network': 'str',
@@ -46,6 +47,7 @@ LINEAGE_FILTRATION_RUN1 = {
    'loss_function': 'SoftmaxCrossEntropyWithLogits',
    'train_dataset_path': None,
    'train_dataset_count': 64,
+    'user_defined': {},
    'test_dataset_path': None,
    'test_dataset_count': 64,
    'network': 'str',
@@ -228,12 +228,12 @@ class TestValidateSearchModelCondition(TestCase):
        )

        condition = {
-            'metric_attribute': {
+            'metric/attribute': {
                'ge': 'xxx'
            }
        }
        self._assert_raise_of_mindinsight_exception(
-            "The parameter metric_attribute is invalid. "
+            "The parameter metric/attribute is invalid. "
            "It should be a dict and the value should be a float or a integer",
            condition
        )
@@ -188,6 +188,22 @@ METRIC_0 = {
    'mse': 3.00000001
}
+
+CUSTOMIZED__0 = {
+    'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'},
+}
+
+CUSTOMIZED_0 = {
+    **CUSTOMIZED__0,
+    'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'},
+    'metric/mse': {'label': 'metric/mse', 'required': True, 'type': 'float'}
+}
+
+CUSTOMIZED_1 = {
+    'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'NoneType'},
+    'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'},
+    'metric/mse': {'label': 'metric/mse', 'required': True, 'type': 'float'}
+}
+
METRIC_1 = {
    'accuracy': 1.0000002,
    'mae': 2.00000002,
@@ -17,7 +17,7 @@ from unittest import TestCase, mock
from google.protobuf.json_format import ParseDict

-import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2
+import mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 as summary_pb2
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageParamTypeError, LineageQuerierParamException,
                                                                 LineageSummaryAnalyzeException,
                                                                 LineageSummaryParseException)
@@ -40,19 +40,19 @@ def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
        namedtuple, parsed lineage info.
    """
    if train_event_dict is not None:
-        train_event = summary_pb2.Event()
+        train_event = summary_pb2.LineageEvent()
        ParseDict(train_event_dict, train_event)
    else:
        train_event = None

    if eval_event_dict is not None:
-        eval_event = summary_pb2.Event()
+        eval_event = summary_pb2.LineageEvent()
        ParseDict(eval_event_dict, eval_event)
    else:
        eval_event = None

    if dataset_event_dict is not None:
-        dataset_event = summary_pb2.Event()
+        dataset_event = summary_pb2.LineageEvent()
        ParseDict(dataset_event_dict, dataset_event)
    else:
        dataset_event = None
@@ -97,6 +97,7 @@ def create_filtration_result(summary_dir, train_event_dict,
        "metric": metric_dict,
        "dataset_graph": dataset_dict,
        "dataset_mark": '2',
+        "user_defined": {}
    }
    return filtration_result
@@ -208,7 +209,9 @@ LINEAGE_FILTRATION_5 = {
    "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
    "metric": {},
    "dataset_graph": event_data.DATASET_DICT_0,
-    "dataset_mark": '2'
+    "dataset_mark": '2',
+    "user_defined": {}
}

LINEAGE_FILTRATION_6 = {
    "summary_dir": '/path/to/summary6',
@@ -228,12 +231,14 @@ LINEAGE_FILTRATION_6 = {
    "model_size": None,
    "metric": event_data.METRIC_5,
    "dataset_graph": event_data.DATASET_DICT_0,
-    "dataset_mark": '2'
+    "dataset_mark": '2',
+    "user_defined": {}
}


class TestQuerier(TestCase):
    """Test the class of `Querier`."""
+    @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_user_defined_info')
    @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos')
    def setUp(self, *args):
        """Initialization before test case execution."""
@@ -242,6 +247,7 @@ class TestQuerier(TestCase):
            event_data.EVENT_EVAL_DICT_0,
            event_data.EVENT_DATASET_DICT_0
        )
+        args[1].return_value = []
        single_summary_path = '/path/to/summary0/log0'
        self.single_querier = Querier(single_summary_path)
@@ -394,6 +400,7 @@ class TestQuerier(TestCase):
            'sorted_name': 'summary_dir'
        }
        expected_result = {
+            'customized': event_data.CUSTOMIZED_0,
            'object': [
                LINEAGE_FILTRATION_1,
                LINEAGE_FILTRATION_2
@@ -418,6 +425,7 @@ class TestQuerier(TestCase):
            'sorted_type': 'descending'
        }
        expected_result = {
+            'customized': event_data.CUSTOMIZED_0,
            'object': [
                LINEAGE_FILTRATION_2,
                LINEAGE_FILTRATION_3
@@ -434,6 +442,7 @@ class TestQuerier(TestCase):
            'offset': 1
        }
        expected_result = {
+            'customized': event_data.CUSTOMIZED_0,
            'object': [
                LINEAGE_FILTRATION_2,
                LINEAGE_FILTRATION_3
@@ -446,6 +455,7 @@ class TestQuerier(TestCase):
    def test_filter_summary_lineage_success_4(self):
        """Test the success of filter_summary_lineage."""
        expected_result = {
+            'customized': event_data.CUSTOMIZED_0,
            'object': [
                LINEAGE_FILTRATION_0,
                LINEAGE_FILTRATION_1,
@@ -468,6 +478,7 @@ class TestQuerier(TestCase):
            }
        }
        expected_result = {
+            'customized': event_data.CUSTOMIZED_0,
            'object': [LINEAGE_FILTRATION_4],
            'count': 1,
        }
@@ -477,10 +488,11 @@ class TestQuerier(TestCase):
    def test_filter_summary_lineage_success_6(self):
        """Test the success of filter_summary_lineage."""
        condition = {
-            'sorted_name': 'metric_accuracy',
+            'sorted_name': 'metric/accuracy',
            'sorted_type': 'ascending'
        }
        expected_result = {
+            'customized': event_data.CUSTOMIZED_0,
            'object': [
                LINEAGE_FILTRATION_0,
                LINEAGE_FILTRATION_5,
@@ -498,10 +510,11 @@ class TestQuerier(TestCase):
    def test_filter_summary_lineage_success_7(self):
        """Test the success of filter_summary_lineage."""
        condition = {
-            'sorted_name': 'metric_accuracy',
+            'sorted_name': 'metric/accuracy',
            'sorted_type': 'descending'
        }
        expected_result = {
+            'customized': event_data.CUSTOMIZED_1,
            'object': [
                LINEAGE_FILTRATION_6,
                LINEAGE_FILTRATION_4,
@@ -519,12 +532,13 @@ class TestQuerier(TestCase):
    def test_filter_summary_lineage_success_8(self):
        """Test the success of filter_summary_lineage."""
        condition = {
-            'metric_accuracy': {
+            'metric/accuracy': {
                'lt': 1.0000006,
                'gt': 1.0000004
            }
        }
        expected_result = {
+            'customized': event_data.CUSTOMIZED_0,
            'object': [LINEAGE_FILTRATION_4],
            'count': 1,
        }
@@ -538,6 +552,7 @@ class TestQuerier(TestCase):
            'offset': 3
        }
        expected_result = {
+            'customized': {},
            'object': [],
            'count': 7,
        }
@@ -594,11 +609,13 @@ class TestQuerier(TestCase):
        with self.assertRaises(LineageSummaryParseException):
            Querier(summary_path)

+    @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_user_defined_info')
    @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos')
    def test_parse_fail_summary_logs_1(self, *args):
        """Test the function of parsing fail summary logs."""
        lineage_infos = get_lineage_infos()
        args[0].side_effect = lineage_infos
+        args[1].return_value = []
        summary_path = ['/path/to/summary0/log0']
        querier = Querier(summary_path)
@@ -611,6 +628,7 @@ class TestQuerier(TestCase):
        self.assertListEqual(expected_result, result)
        self.assertListEqual([], querier._parse_failed_paths)

+    @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_user_defined_info')
    @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos')
    def test_parse_fail_summary_logs_2(self, *args):
        """Test the function of parsing fail summary logs."""
@@ -619,6 +637,7 @@ class TestQuerier(TestCase):
            event_data.EVENT_EVAL_DICT_0,
            event_data.EVENT_DATASET_DICT_0,
        )
+        args[1].return_value = []
        summary_path = ['/path/to/summary0/log0']
        querier = Querier(summary_path)
@@ -16,7 +16,7 @@
from unittest import mock, TestCase
from unittest.mock import MagicMock

-from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Event
+from mindinsight.datavisual.proto_files.mindinsight_lineage_pb2 import LineageEvent
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageVerificationException, \
    LineageSummaryAnalyzeException
from mindinsight.lineagemgr.common.log import logger as log
@@ -57,7 +57,7 @@ class TestSummaryAnalyzer(TestCase):
    @mock.patch.object(SummaryAnalyzer, '_read_header')
    @mock.patch.object(SummaryAnalyzer, '_read_body')
-    @mock.patch.object(Event, 'FromString')
+    @mock.patch.object(LineageEvent, 'FromString')
    def test_read_event(self, *args):
        """Test read_event method."""
        args[2].return_value = 10