Browse Source

fix the problem of not counting the ST case

tags/v1.0.0
shenghong 5 years ago
parent
commit
0d495c48a7
1 changed files with 4 additions and 40 deletions
  1. +4
    -40
      tests/st/func/lineagemgr/test_model.py

+ 4
- 40
tests/st/func/lineagemgr/test_model.py View File

@@ -21,8 +21,7 @@ Usage:
pytest lineagemgr
"""
import os
from unittest import TestCase, mock
import numpy as np
from unittest import TestCase
import pytest

from mindinsight.lineagemgr.model import filter_summary_lineage, get_summary_lineage
@@ -32,12 +31,6 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
from mindinsight.lineagemgr.model import get_flattened_lineage
from mindspore.application.model_zoo.resnet import ResNet
from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import RunContext
from ....utils.lineage_writer.model_lineage import AnalyzeObject, TrainLineage

from .conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
from ....ut.lineagemgr.querier import event_data
@@ -825,44 +818,15 @@ class TestModelApi(TestCase):
search_condition
)


class TestLineageTable:
"""Test lineage table ."""

@classmethod
def setup_class(cls):
"""Setup method"""
cls.run_context = dict(
train_network=ResNet(),
loss_fn=SoftmaxCrossEntropyWithLogits(),
net_outputs=Tensor(np.array([0.03])),
optimizer=Momentum(Tensor(0.12)),
train_dataset=MindDataset(dataset_size=32),
epoch_num=10,
cur_step_num=320,
parallel_mode="stand_alone",
device_number=2,
batch_num=32
)
cls.user_defined_info = {"info": "info1", "version": "v1"}

@pytest.mark.scene_train(2)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascned_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
@mock.patch.object(AnalyzeObject, 'get_file_size')
def test_training_end(self):
def test_get_flattened_lineage(self):
"""Test the function of get_flattened_lineage"""
train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)

train_callback.initial_learning_rate = 0.12
train_callback.begin(RunContext(self.run_context))
train_callback.end(RunContext(self.run_context))

summary_base_dir = SUMMARY_DIR
datamanager = data_manager.DataManager(summary_base_dir)
datamanager = data_manager.DataManager(SUMMARY_DIR)
datamanager.register_brief_cache_item_updater(LineageCacheItemUpdater())
datamanager.start_load_data().join()



Loading…
Cancel
Save