| @@ -284,8 +284,8 @@ class EvalLineage(Callback): | |||
| self.lineage_summary = LineageSummary(self.lineage_log_dir) | |||
| self.user_defined_info = user_defined_info | |||
| if user_defined_info: | |||
| validate_user_defined_info(user_defined_info) | |||
| if self.user_defined_info: | |||
| validate_user_defined_info(self.user_defined_info) | |||
| except MindInsightException as err: | |||
| log.error(err) | |||
| @@ -410,7 +410,7 @@ def validate_path(summary_path): | |||
| def validate_user_defined_info(user_defined_info): | |||
| """ | |||
| Validate user defined info. | |||
| Validate user defined info, delete the item if its key is in lineage. | |||
| Args: | |||
| user_defined_info (dict): The user defined info. | |||
| @@ -437,10 +437,13 @@ def validate_user_defined_info(user_defined_info): | |||
| field_map = set(FIELD_MAPPING.keys()) | |||
| user_defined_keys = set(user_defined_info.keys()) | |||
| all_keys = field_map | user_defined_keys | |||
| insertion = list(field_map & user_defined_keys) | |||
| if len(field_map) + len(user_defined_keys) != len(all_keys): | |||
| raise LineageParamValueError("There are some keys have defined in lineage.") | |||
| if insertion: | |||
| for key in insertion: | |||
| user_defined_info.pop(key) | |||
| raise LineageParamValueError("There are some keys have defined in lineage. " | |||
| "Duplicated key(s): %s. " % insertion) | |||
| def validate_train_id(relative_path): | |||
| @@ -92,7 +92,7 @@ LINEAGE_FILTRATION_RUN1 = { | |||
| 'train_dataset_count': 1024, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 1024, | |||
| 'user_defined': {}, | |||
| 'user_defined': {'info': 'info1', 'version': 'v1'}, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| @@ -329,7 +329,7 @@ class TestModelApi(TestCase): | |||
| def test_filter_summary_lineage(self): | |||
| """Test the interface of filter_summary_lineage.""" | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1, | |||
| @@ -383,7 +383,7 @@ class TestModelApi(TestCase): | |||
| 'offset': 0 | |||
| } | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_RUN2, | |||
| LINEAGE_FILTRATION_RUN1 | |||
| @@ -421,7 +421,7 @@ class TestModelApi(TestCase): | |||
| 'offset': 0 | |||
| } | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_RUN2, | |||
| LINEAGE_FILTRATION_RUN1 | |||
| @@ -449,7 +449,7 @@ class TestModelApi(TestCase): | |||
| 'sorted_name': 'metric/accuracy', | |||
| } | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1, | |||
| @@ -70,7 +70,7 @@ class TestModelApi(TestCase): | |||
| def test_filter_summary_lineage(self): | |||
| """Test the interface of filter_summary_lineage.""" | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1, | |||
| @@ -28,7 +28,7 @@ from unittest import mock, TestCase | |||
| import numpy as np | |||
| import pytest | |||
| from mindinsight.lineagemgr import get_summary_lineage | |||
| from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage | |||
| from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ | |||
| AnalyzeObject | |||
| from mindinsight.lineagemgr.common.utils import make_directory | |||
| @@ -109,6 +109,36 @@ class TestModelLineage(TestCase): | |||
| lineage_log_path = train_callback.lineage_summary.lineage_log_path | |||
| assert os.path.isfile(lineage_log_path) is True | |||
| @pytest.mark.scene_train(2) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_train_begin_with_user_defined_key_in_lineage(self): | |||
| """Test TrainLineage with nested user defined info.""" | |||
| expected_res = { | |||
| "info": "info1", | |||
| "version": "v1" | |||
| } | |||
| user_defined_info = { | |||
| "info": "info1", | |||
| "version": "v1", | |||
| "network": "LeNet" | |||
| } | |||
| train_callback = TrainLineage( | |||
| self.summary_record, | |||
| False, | |||
| user_defined_info | |||
| ) | |||
| train_callback.begin(RunContext(self.run_context)) | |||
| assert train_callback.initial_learning_rate == 0.12 | |||
| lineage_log_path = train_callback.lineage_summary.lineage_log_path | |||
| assert os.path.isfile(lineage_log_path) is True | |||
| res = filter_summary_lineage(os.path.dirname(lineage_log_path)) | |||
| assert expected_res == res['object'][0]['model_lineage']['user_defined'] | |||
| @pytest.mark.scene_train(2) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @@ -192,6 +192,12 @@ CUSTOMIZED__0 = { | |||
| 'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'}, | |||
| } | |||
| CUSTOMIZED__1 = { | |||
| **CUSTOMIZED__0, | |||
| 'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'}, | |||
| 'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'} | |||
| } | |||
| CUSTOMIZED_0 = { | |||
| **CUSTOMIZED__0, | |||
| 'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'}, | |||