| @@ -21,8 +21,7 @@ Usage: | |||
| pytest lineagemgr | |||
| """ | |||
| import os | |||
| from unittest import TestCase, mock | |||
| import numpy as np | |||
| from unittest import TestCase | |||
| import pytest | |||
| from mindinsight.lineagemgr.model import filter_summary_lineage, get_summary_lineage | |||
| @@ -32,12 +31,6 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater | |||
| from mindinsight.lineagemgr.model import get_flattened_lineage | |||
| from mindspore.application.model_zoo.resnet import ResNet | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.dataset.engine import MindDataset | |||
| from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits | |||
| from mindspore.train.callback import RunContext | |||
| from ....utils.lineage_writer.model_lineage import AnalyzeObject, TrainLineage | |||
| from .conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 | |||
| from ....ut.lineagemgr.querier import event_data | |||
| @@ -825,44 +818,15 @@ class TestModelApi(TestCase): | |||
| search_condition | |||
| ) | |||
| class TestLineageTable: | |||
| """Test lineage table .""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Setup method""" | |||
| cls.run_context = dict( | |||
| train_network=ResNet(), | |||
| loss_fn=SoftmaxCrossEntropyWithLogits(), | |||
| net_outputs=Tensor(np.array([0.03])), | |||
| optimizer=Momentum(Tensor(0.12)), | |||
| train_dataset=MindDataset(dataset_size=32), | |||
| epoch_num=10, | |||
| cur_step_num=320, | |||
| parallel_mode="stand_alone", | |||
| device_number=2, | |||
| batch_num=32 | |||
| ) | |||
| cls.user_defined_info = {"info": "info1", "version": "v1"} | |||
| @pytest.mark.scene_train(2) | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascned_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| @mock.patch.object(AnalyzeObject, 'get_file_size') | |||
| def test_training_end(self): | |||
| def test_get_flattened_lineage(self): | |||
| """Test the function of get_flattened_lineage""" | |||
| train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info) | |||
| train_callback.initial_learning_rate = 0.12 | |||
| train_callback.begin(RunContext(self.run_context)) | |||
| train_callback.end(RunContext(self.run_context)) | |||
| summary_base_dir = SUMMARY_DIR | |||
| datamanager = data_manager.DataManager(summary_base_dir) | |||
| datamanager = data_manager.DataManager(SUMMARY_DIR) | |||
| datamanager.register_brief_cache_item_updater(LineageCacheItemUpdater()) | |||
| datamanager.start_load_data().join() | |||