Browse Source

fix the problem of not counting the ST case

tags/v1.0.0
shenghong 5 years ago
parent
commit
0d495c48a7
1 changed files with 4 additions and 40 deletions
  1. +4
    -40
      tests/st/func/lineagemgr/test_model.py

+ 4
- 40
tests/st/func/lineagemgr/test_model.py View File

@@ -21,8 +21,7 @@ Usage:
pytest lineagemgr pytest lineagemgr
""" """
import os import os
from unittest import TestCase, mock
import numpy as np
from unittest import TestCase
import pytest import pytest


from mindinsight.lineagemgr.model import filter_summary_lineage, get_summary_lineage from mindinsight.lineagemgr.model import filter_summary_lineage, get_summary_lineage
@@ -32,12 +31,6 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
from mindinsight.lineagemgr.model import get_flattened_lineage from mindinsight.lineagemgr.model import get_flattened_lineage
from mindspore.application.model_zoo.resnet import ResNet
from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import RunContext
from ....utils.lineage_writer.model_lineage import AnalyzeObject, TrainLineage


from .conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 from .conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
from ....ut.lineagemgr.querier import event_data from ....ut.lineagemgr.querier import event_data
@@ -825,44 +818,15 @@ class TestModelApi(TestCase):
search_condition search_condition
) )



class TestLineageTable:
"""Test lineage table ."""

@classmethod
def setup_class(cls):
"""Setup method"""
cls.run_context = dict(
train_network=ResNet(),
loss_fn=SoftmaxCrossEntropyWithLogits(),
net_outputs=Tensor(np.array([0.03])),
optimizer=Momentum(Tensor(0.12)),
train_dataset=MindDataset(dataset_size=32),
epoch_num=10,
cur_step_num=320,
parallel_mode="stand_alone",
device_number=2,
batch_num=32
)
cls.user_defined_info = {"info": "info1", "version": "v1"}

@pytest.mark.scene_train(2)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascned_training @pytest.mark.platform_x86_ascned_training
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_single @pytest.mark.env_single
@mock.patch.object(AnalyzeObject, 'get_file_size')
def test_training_end(self):
def test_get_flattened_lineage(self):
"""Test the function of get_flattened_lineage""" """Test the function of get_flattened_lineage"""
train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)

train_callback.initial_learning_rate = 0.12
train_callback.begin(RunContext(self.run_context))
train_callback.end(RunContext(self.run_context))

summary_base_dir = SUMMARY_DIR
datamanager = data_manager.DataManager(summary_base_dir)
datamanager = data_manager.DataManager(SUMMARY_DIR)
datamanager.register_brief_cache_item_updater(LineageCacheItemUpdater()) datamanager.register_brief_cache_item_updater(LineageCacheItemUpdater())
datamanager.start_load_data().join() datamanager.start_load_data().join()




Loading…
Cancel
Save