Browse Source

add some ut for list explain directories and reduct code function length

tags/v1.1.0
ougongchang 5 years ago
parent
commit
b5e21cf4ab
4 changed files with 96 additions and 9 deletions
  1. +6
    -1
      mindinsight/datavisual/data_transform/loader_generators/data_loader_generator.py
  2. +13
    -4
      mindinsight/datavisual/data_transform/summary_watcher.py
  3. +20
    -0
      tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py
  4. +57
    -4
      tests/ut/datavisual/data_transform/test_summary_watcher.py

+ 6
- 1
mindinsight/datavisual/data_transform/loader_generators/data_loader_generator.py View File

@@ -19,6 +19,7 @@ This module generate loaders from summary logs.
"""
import os
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
@@ -26,6 +27,7 @@ from mindinsight.datavisual.data_transform.loader_generators.loader_struct impor
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import LoaderGenerator
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.exceptions import PathNotExistError


class DataLoaderGenerator(LoaderGenerator):
@@ -241,6 +243,9 @@ class DataLoaderGenerator(LoaderGenerator):

"""
relative_path = self._get_relative_path_from_train_id(train_id)
loader = self._generate_loader_by_relative_path(relative_path)
try:
loader = self._generate_loader_by_relative_path(relative_path)
except PathNotExistError as ex:
raise TrainJobNotExistError(str(ex))

return loader

+ 13
- 4
mindinsight/datavisual/data_transform/summary_watcher.py View File

@@ -230,10 +230,9 @@ class SummaryWatcher:
elif entry.is_dir():
if list_explain:
return
profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
full_dir_path = os.path.join(summary_base_dir, relative_path, entry.name)
is_valid_profiler_dir, profiler_type = self._is_valid_profiler_directory(full_dir_path)
if profiler_pattern is None or not is_valid_profiler_dir:
profiler_type, is_find = self._find_profiler_dir(entry, summary_base_dir, relative_path)
if not is_find:
return
profiler = {
@@ -248,6 +247,16 @@ class SummaryWatcher:
else:
summary_dict[relative_path] = _new_entry(ctime, mtime, profiler)
def _find_profiler_dir(self, entry, summary_base_dir, relative_path):
"""Find profiler dir by the given relative path."""
profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
full_dir_path = os.path.join(summary_base_dir, relative_path, entry.name)
is_valid_profiler_dir, profiler_type = self._is_valid_profiler_directory(full_dir_path)
if profiler_pattern is None or not is_valid_profiler_dir:
return profiler_type, False
return profiler_type, True
def _is_valid_pattern_result(self, summary_pattern, pb_pattern, list_explain, entry):
"""Check the pattern result is valid."""
if summary_pattern is None and pb_pattern is None:


+ 20
- 0
tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py View File

@@ -134,6 +134,26 @@ class TestPlugins:
assert response['error_msg'] == "Invalid parameter value. The value of " \
"manual_update must be 'false' or 'true'."

@pytest.mark.level1
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs")
@pytest.mark.parametrize("manual_update", [True, False])
def test_manual_update_with_train_job_not_exit(self, client, manual_update):
"""Test passing manual_update with special character, wrong value, and wrong type."""
params = dict(train_id='./123', manual_update=manual_update)
url = get_url(BASE_URL, params)

response = client.get(url)
assert response.status_code == 400

response = response.get_json()
assert response['error_code'] == '50545005'
assert 'Train job is not exist' in response['error_msg']

@pytest.mark.level1
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu


+ 57
- 4
tests/ut/datavisual/data_transform/test_summary_watcher.py View File

@@ -27,6 +27,8 @@ import tempfile

from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher

import pytest


def gen_directories_and_files(summary_base_dir, file_count, directory_count):
"""Generate directories and files for test."""
@@ -52,12 +54,36 @@ def gen_directories_and_files(summary_base_dir, file_count, directory_count):
shutil.copytree(os.path.join(summary_base_dir, 'run'), os.path.join(summary_base_dir, f'run{index}'))


def gen_explain_directories_and_files(summary_base_dir, relative_path):
"""Generate XAI directories and files."""
end_time = datetime.datetime.now()
start_time = end_time - datetime.timedelta(days=10)
start_ts = int(start_time.timestamp())
end_ts = int(end_time.timestamp())
os.mkdir(os.path.join(summary_base_dir, relative_path))
summary = os.path.join(summary_base_dir,
relative_path,
f'prefix.summary.{random.randint(start_ts, end_ts)}._explain')
with open(summary, 'w'):
pass


class TestSummaryWatcher:
"""Test summary watcher."""
base_dir = ''

def setup_class(self):
"""Mock common environment for graph unittest."""
self.base_dir = tempfile.mkdtemp()

def teardown_class(self):
"""Delete temp files."""
if os.path.exists(self.base_dir):
shutil.rmtree(self.base_dir)

def test_list_summary_directories_with_overall_on(self):
"""Test list_summary_directories method success."""
summary_base_dir = tempfile.mkdtemp()
summary_base_dir = tempfile.mkdtemp(dir=self.base_dir)
file_count = 10
directory_count = 10
gen_directories_and_files(summary_base_dir, file_count, directory_count)
@@ -70,7 +96,7 @@ class TestSummaryWatcher:

def test_list_summary_directories_by_pagination(self):
"""Test list_summary_directories method success."""
summary_base_dir = tempfile.mkdtemp()
summary_base_dir = tempfile.mkdtemp(dir=self.base_dir)
file_count = 10
directory_count = 10
gen_directories_and_files(summary_base_dir, file_count, directory_count)
@@ -90,7 +116,7 @@ class TestSummaryWatcher:

def test_is_summary_directory(self):
"""Test is_summary_directory method success."""
summary_base_dir = tempfile.mkdtemp()
summary_base_dir = tempfile.mkdtemp(dir=self.base_dir)
file_count = 1
directory_count = 1
gen_directories_and_files(summary_base_dir, file_count, directory_count)
@@ -104,7 +130,7 @@ class TestSummaryWatcher:

def test_list_summaries(self):
"""Test list_summaries method success."""
summary_base_dir = tempfile.mkdtemp()
summary_base_dir = tempfile.mkdtemp(dir=self.base_dir)
file_count = 10
directory_count = 1
gen_directories_and_files(summary_base_dir, file_count, directory_count)
@@ -115,3 +141,30 @@ class TestSummaryWatcher:
summaries = summary_watcher.list_summaries(summary_base_dir, './\x00')
assert not summaries
shutil.rmtree(summary_base_dir)

@pytest.mark.parametrize("job_count", [0, 1, 3])
def test_list_explain_directories_only_base_dir(self, job_count):
"""Test list explain directories with summary base dir, and no test offset and limit."""
summary_base_dir = tempfile.mkdtemp(dir=self.base_dir)
if job_count:
for i in range(job_count):
gen_explain_directories_and_files(summary_base_dir, f'run{i}')
summary_watcher = SummaryWatcher()
total, _ = summary_watcher.list_explain_directories(summary_base_dir)
assert total == job_count
shutil.rmtree(summary_base_dir)

@pytest.mark.parametrize("offset, limit", [(1, 1), (2, 2), (3, 3)])
def test_list_explain_dir_with_offset_limit(self, offset, limit):
"""Test list explain dir with offset and limit."""
summary_base_dir = tempfile.mkdtemp(dir=self.base_dir)
gen_directories_and_files(summary_base_dir, file_count=1, directory_count=3)
for i in range(10):
gen_explain_directories_and_files(summary_base_dir, f'run_{i}')

summary_watcher = SummaryWatcher()
_, result = summary_watcher.list_explain_directories(summary_base_dir, offset, limit)
if offset == 3:
assert len(result) == 1
else:
assert len(result) == limit

Loading…
Cancel
Save