diff --git a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py index bafc2484..ba5b6277 100644 --- a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py @@ -108,16 +108,7 @@ class TestModelLineage(TestCase): res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER, search_condition=self._search_condition) assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 14 - @pytest.mark.scene_eval(3) - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend_training - @pytest.mark.platform_x86_gpu_training - @pytest.mark.platform_x86_ascend_training - @pytest.mark.platform_x86_cpu - @pytest.mark.env_single - def test_eval_end(self): - """Test the end function in EvalLineage.""" - eval_callback = EvalLineage(self.summary_record, True, {'eval_version': 'version2'}) + eval_callback = EvalLineage(self.summary_record, True, self.user_defined_info) eval_run_context = self.run_context eval_run_context['metrics'] = {'accuracy': 0.78} eval_run_context['valid_dataset'] = self.run_context['train_dataset'] @@ -152,7 +143,7 @@ class TestModelLineage(TestCase): SUMMARY_DIR_2, f'train_out.events.summary.{str(int(time.time()) + 2*i)}.ubuntu_lineage' ) - train_callback = TrainLineage(summary_record, True) + train_callback = TrainLineage(summary_record, True, self.user_defined_info) train_callback.begin(RunContext(self.run_context)) train_callback.end(RunContext(self.run_context)) @@ -160,7 +151,7 @@ class TestModelLineage(TestCase): SUMMARY_DIR_2, f'eval_out.events.summary.{str(int(time.time())+ 2*i + 1)}.ubuntu_lineage' ) - eval_callback = EvalLineage(eval_record, True) + eval_callback = EvalLineage(eval_record, True, {'eval_version': 'version2'}) eval_run_context = self.run_context eval_run_context['metrics'] = {'accuracy': 0.78 + i + 1} eval_run_context['valid_dataset'] = self.run_context['train_dataset'] @@ -169,7 +160,7 @@ class TestModelLineage(TestCase): file_num = os.listdir(SUMMARY_DIR_2) assert len(file_num) == 8 - @pytest.mark.scene_train(2) + @pytest.mark.scene_train(3) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_gpu_training diff --git a/tests/st/func/lineagemgr/test_model.py b/tests/st/func/lineagemgr/test_model.py index 89a0eb2c..8fbcba7a 100644 --- a/tests/st/func/lineagemgr/test_model.py +++ b/tests/st/func/lineagemgr/test_model.py @@ -97,8 +97,7 @@ LINEAGE_FILTRATION_RUN1 = { 'test_dataset_count': 1024, 'user_defined': { 'info': 'info1', - 'version': 'v1', - 'eval_version': 'version2' + 'version': 'v1' }, 'network': 'ResNet', 'optimizer': 'Momentum', @@ -124,7 +123,11 @@ LINEAGE_FILTRATION_RUN2 = { 'train_dataset_count': 1024, 'test_dataset_path': None, 'test_dataset_count': 1024, - 'user_defined': {}, + 'user_defined': { + 'info': 'info1', + 'version': 'v1', + 'eval_version': 'version2' + }, 'network': "ResNet", 'optimizer': "Momentum", 'learning_rate': 0.12, diff --git a/tests/ut/lineagemgr/test_model.py b/tests/ut/lineagemgr/test_model.py index fecf37b7..836d680f 100644 --- a/tests/ut/lineagemgr/test_model.py +++ b/tests/ut/lineagemgr/test_model.py @@ -20,7 +20,8 @@ from mindinsight.lineagemgr.model import filter_summary_lineage, get_flattened_l from mindinsight.lineagemgr.common.exceptions.exceptions import LineageSummaryParseException, \ LineageQuerierParamException, LineageQuerySummaryDataError, LineageSearchConditionParamError, LineageParamTypeError from mindinsight.lineagemgr.common.path_parser import SummaryPathParser -from ...st.func.lineagemgr.test_model import LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 +from tests.st.func.lineagemgr.test_model import LINEAGE_FILTRATION_EXCEPT_RUN, \ + LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 class TestFilterAPI(TestCase): @@ -84,9 +85,9 @@ class TestFilterAPI(TestCase): def test_get_lineage_table(self, mock_filter_summary_lineage): """Test get_flattened_lineage with valid param.""" mock_data = { - 'object': [LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2] + 'object': [LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2] } mock_data_manager = MagicMock() mock_filter_summary_lineage.return_value = mock_data result = get_flattened_lineage(mock_data_manager) - assert result.get('[U]info') == ['info1', None] + assert result.get('[U]info') == [None, 'info1', 'info1']