| @@ -0,0 +1,26 @@ | |||
| <!-- Thanks for sending a pull request! Here are some tips for you: | |||
| If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md | |||
| --> | |||
| **What type of PR is this?** | |||
| > Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||
| > | |||
| > /kind bug | |||
| > /kind task | |||
| > /kind feature | |||
| **What does this PR do / why do we need it**: | |||
| **Which issue(s) this PR fixes**: | |||
| <!-- | |||
| *Automatically closes linked issue when PR is merged. | |||
| Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. | |||
| --> | |||
| Fixes # | |||
| **Special notes for your reviewers**: | |||
| @@ -0,0 +1,19 @@ | |||
| --- | |||
| name: RFC | |||
| about: Use this template for the new feature or enhancement | |||
| labels: kind/feature or kind/enhancement | |||
| --- | |||
| ## Background | |||
| - Describe the status of the problem you wish to solve | |||
| - Attach the relevant issue if have | |||
| ## Introduction | |||
| - Describe the general solution, design and/or pseudo-code | |||
| ## Trail | |||
| | No. | Task Description | Related Issue(URL) | | |||
| | --- | ---------------- | ------------------ | | |||
| | 1 | | | | |||
| | 2 | | | | |||
| @@ -0,0 +1,43 @@ | |||
| --- | |||
| name: Bug Report | |||
| about: Use this template for reporting a bug | |||
| labels: kind/bug | |||
| --- | |||
| <!-- Thanks for sending an issue! Here are some tips for you: | |||
| If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md | |||
| --> | |||
| ## Environment | |||
| ### Hardware Environment(`Ascend`/`GPU`/`CPU`): | |||
| > Uncomment only one ` /device <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||
| > | |||
| > `/device ascend`</br> | |||
| > `/device gpu`</br> | |||
| > `/device cpu`</br> | |||
| ### Software Environment: | |||
| - **MindSpore version (source or binary)**: | |||
| - **Python version (e.g., Python 3.7.5)**: | |||
| - **OS platform and distribution (e.g., Linux Ubuntu 16.04)**: | |||
| - **GCC/Compiler version (if compiled from source)**: | |||
| ## Describe the current behavior | |||
| ## Describe the expected behavior | |||
| ## Steps to reproduce the issue | |||
| 1. | |||
| 2. | |||
| 3. | |||
| ## Related log / screenshot | |||
| ## Special notes for this issue | |||
| @@ -0,0 +1,19 @@ | |||
| --- | |||
| name: Task | |||
| about: Use this template for task tracking | |||
| labels: kind/task | |||
| --- | |||
| ## Task Description | |||
| ## Task Goal | |||
| ## Sub Task | |||
| | No. | Task Description | Issue ID | | |||
| | --- | ---------------- | -------- | | |||
| | 1 | | | | |||
| | 2 | | | | |||
| @@ -0,0 +1,24 @@ | |||
| <!-- Thanks for sending a pull request! Here are some tips for you: | |||
| If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md | |||
| --> | |||
| **What type of PR is this?** | |||
| > Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||
| > | |||
| > `/kind bug`</br> | |||
| > `/kind task`</br> | |||
| > `/kind feature`</br> | |||
| **What does this PR do / why do we need it**: | |||
| **Which issue(s) this PR fixes**: | |||
| <!-- | |||
| *Automatically closes linked issue when PR is merged. | |||
| Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. | |||
| --> | |||
| Fixes # | |||
| **Special notes for your reviewers**: | |||
| @@ -18,19 +18,21 @@ This file is used to define the basic graph. | |||
| import copy | |||
| import time | |||
| from enum import Enum | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.common import exceptions | |||
| from .node import NodeTypeEnum | |||
| from .node import Node | |||
| class EdgeTypeEnum: | |||
| class EdgeTypeEnum(Enum): | |||
| """Node edge type enum.""" | |||
| control = 'control' | |||
| data = 'data' | |||
| CONTROL = 'control' | |||
| DATA = 'data' | |||
| class DataTypeEnum: | |||
| class DataTypeEnum(Enum): | |||
| """Data type enum.""" | |||
| DT_TENSOR = 13 | |||
| @@ -292,70 +294,65 @@ class Graph: | |||
| output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value | |||
| node.update_output({dst_name: output_attr}) | |||
| def _calc_polymeric_input_output(self): | |||
| def _update_polymeric_input_output(self): | |||
| """Calc polymeric input and output after build polymeric node.""" | |||
| for name, node in self._normal_nodes.items(): | |||
| polymeric_input = {} | |||
| for src_name in node.input: | |||
| src_node = self._polymeric_nodes.get(src_name) | |||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||
| src_name = src_name if not src_node else src_node.polymeric_scope_name | |||
| output_name = self._calc_dummy_node_name(name, src_name) | |||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| continue | |||
| if not src_node: | |||
| continue | |||
| if not node.name_scope and src_node.name_scope: | |||
| # if current node is in first layer, and the src node is not in | |||
| # the first layer, the src node will not be the polymeric input of current node. | |||
| continue | |||
| if node.name_scope == src_node.name_scope \ | |||
| or node.name_scope.startswith(src_node.name_scope): | |||
| polymeric_input.update( | |||
| {src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| for node in self._normal_nodes.values(): | |||
| polymeric_input = self._calc_polymeric_attr(node, 'input') | |||
| node.update_polymeric_input(polymeric_input) | |||
| polymeric_output = {} | |||
| for dst_name in node.output: | |||
| dst_node = self._polymeric_nodes.get(dst_name) | |||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||
| dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name | |||
| output_name = self._calc_dummy_node_name(name, dst_name) | |||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| continue | |||
| if not dst_node: | |||
| continue | |||
| if not node.name_scope and dst_node.name_scope: | |||
| continue | |||
| if node.name_scope == dst_node.name_scope \ | |||
| or node.name_scope.startswith(dst_node.name_scope): | |||
| polymeric_output.update( | |||
| {dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| polymeric_output = self._calc_polymeric_attr(node, 'output') | |||
| node.update_polymeric_output(polymeric_output) | |||
| for name, node in self._polymeric_nodes.items(): | |||
| polymeric_input = {} | |||
| for src_name in node.input: | |||
| output_name = self._calc_dummy_node_name(name, src_name) | |||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| node.update_polymeric_input(polymeric_input) | |||
| polymeric_output = {} | |||
| for dst_name in node.output: | |||
| polymeric_output = {} | |||
| output_name = self._calc_dummy_node_name(name, dst_name) | |||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| node.update_polymeric_output(polymeric_output) | |||
| def _calc_polymeric_attr(self, node, attr): | |||
| """ | |||
| Calc polymeric input or polymeric output after build polymeric node. | |||
| Args: | |||
| node (Node): Computes the polymeric input for a given node. | |||
| attr (str): The polymeric attr, optional value is `input` or `output`. | |||
| Returns: | |||
| dict, return polymeric input or polymeric output of the given node. | |||
| """ | |||
| polymeric_attr = {} | |||
| for node_name in getattr(node, attr): | |||
| polymeric_node = self._polymeric_nodes.get(node_name) | |||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||
| node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name | |||
| dummy_node_name = self._calc_dummy_node_name(node.name, node_name) | |||
| polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| continue | |||
| if not polymeric_node: | |||
| continue | |||
| if not node.name_scope and polymeric_node.name_scope: | |||
| # If current node is in top-level layer, and the polymeric_node node is not in | |||
| # the top-level layer, the polymeric node will not be the polymeric input | |||
| # or polymeric output of current node. | |||
| continue | |||
| if node.name_scope == polymeric_node.name_scope \ | |||
| or node.name_scope.startswith(polymeric_node.name_scope + '/'): | |||
| polymeric_attr.update( | |||
| {polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| return polymeric_attr | |||
| def _calc_dummy_node_name(self, current_node_name, other_node_name): | |||
| """ | |||
| Calc dummy node name. | |||
| @@ -39,7 +39,7 @@ class MSGraph(Graph): | |||
| self._build_leaf_nodes(graph_proto) | |||
| self._build_polymeric_nodes() | |||
| self._build_name_scope_nodes() | |||
| self._calc_polymeric_input_output() | |||
| self._update_polymeric_input_output() | |||
| logger.info("Build graph end, normal node count: %s, polymeric node " | |||
| "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes)) | |||
| @@ -90,9 +90,9 @@ class MSGraph(Graph): | |||
| node_name = leaf_node_id_map_name[node_def.name] | |||
| node = self._leaf_nodes[node_name] | |||
| for input_def in node_def.input: | |||
| edge_type = EdgeTypeEnum.data | |||
| edge_type = EdgeTypeEnum.DATA.value | |||
| if input_def.type == "CONTROL_EDGE": | |||
| edge_type = EdgeTypeEnum.control | |||
| edge_type = EdgeTypeEnum.CONTROL.value | |||
| if const_nodes_map.get(input_def.name): | |||
| const_node = copy.deepcopy(const_nodes_map[input_def.name]) | |||
| @@ -218,7 +218,7 @@ class MSGraph(Graph): | |||
| node = Node(name=const.key, node_id=const_node_id) | |||
| node.node_type = NodeTypeEnum.CONST.value | |||
| node.update_attr({const.key: str(const.value)}) | |||
| if const.value.dtype == DataTypeEnum.DT_TENSOR: | |||
| if const.value.dtype == DataTypeEnum.DT_TENSOR.value: | |||
| shape = [] | |||
| for dim in const.value.tensor_val.dims: | |||
| shape.append(dim) | |||
| @@ -172,7 +172,7 @@ class Node: | |||
| Args: | |||
| polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name: | |||
| {'edge_type': EdgeTypeEnum.data}}). | |||
| {'edge_type': EdgeTypeEnum.DATA.value}}). | |||
| """ | |||
| self._polymeric_output.update(polymeric_output) | |||
| @@ -168,7 +168,7 @@ class TrainLineage(Callback): | |||
| train_lineage = AnalyzeObject.get_network_args( | |||
| run_context_args, train_lineage | |||
| ) | |||
| train_dataset = run_context_args.get('train_dataset') | |||
| callbacks = run_context_args.get('list_callback') | |||
| list_callback = getattr(callbacks, '_callbacks', []) | |||
| @@ -601,7 +601,7 @@ class AnalyzeObject: | |||
| loss = None | |||
| else: | |||
| loss = run_context_args.get('net_outputs') | |||
| if loss: | |||
| log.info('Calculating loss...') | |||
| loss_numpy = loss.asnumpy() | |||
| @@ -610,7 +610,7 @@ class AnalyzeObject: | |||
| train_lineage[Metadata.loss] = loss | |||
| else: | |||
| train_lineage[Metadata.loss] = None | |||
| # Analyze classname of optimizer, loss function and training network. | |||
| train_lineage[Metadata.optimizer] = type(optimizer).__name__ \ | |||
| if optimizer else None | |||
| @@ -18,13 +18,10 @@ Description: This file is used for some common util. | |||
| import os | |||
| import shutil | |||
| from unittest.mock import Mock | |||
| import pytest | |||
| from flask import Response | |||
| from tests.st.func.datavisual import constants | |||
| from tests.st.func.datavisual.utils.log_operations import LogOperations | |||
| from tests.st.func.datavisual.utils.utils import check_loading_done | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| from mindinsight.datavisual.data_transform.data_manager import DataManager | |||
| @@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat | |||
| from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE | |||
| from mindinsight.datavisual.utils import tools | |||
| from ....utils.log_operations import LogOperations | |||
| from ....utils.tools import check_loading_done | |||
| from . import constants | |||
| from . import globals as gbl | |||
| summaries_metadata = None | |||
| mock_data_manager = None | |||
| summary_base_dir = constants.SUMMARY_BASE_DIR | |||
| @@ -55,17 +57,21 @@ def init_summary_logs(): | |||
| os.mkdir(summary_base_dir, mode=mode) | |||
| global summaries_metadata, mock_data_manager | |||
| log_operations = LogOperations() | |||
| summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST) | |||
| summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, | |||
| constants.SUMMARY_DIR_PREFIX) | |||
| mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)]) | |||
| mock_data_manager.start_load_data(reload_interval=0) | |||
| check_loading_done(mock_data_manager) | |||
| summaries_metadata.update(log_operations.create_summary_logs( | |||
| summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, constants.SUMMARY_DIR_NUM_FIRST)) | |||
| summaries_metadata.update(log_operations.create_multiple_logs( | |||
| summary_base_dir, constants.MULTIPLE_DIR_NAME, constants.MULTIPLE_LOG_NUM)) | |||
| summaries_metadata.update(log_operations.create_reservoir_log( | |||
| summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM)) | |||
| summaries_metadata.update( | |||
| log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, | |||
| constants.SUMMARY_DIR_NUM_FIRST)) | |||
| summaries_metadata.update( | |||
| log_operations.create_multiple_logs(summary_base_dir, constants.MULTIPLE_DIR_NAME, | |||
| constants.MULTIPLE_LOG_NUM)) | |||
| summaries_metadata.update( | |||
| log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME, | |||
| constants.RESERVOIR_STEP_NUM)) | |||
| mock_data_manager.start_load_data(reload_interval=0) | |||
| # Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING. | |||
| @@ -73,7 +79,7 @@ def init_summary_logs(): | |||
| # Maximum number of loads is `MAX_DATA_LOADER_SIZE`. | |||
| for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE): | |||
| summaries_metadata.pop("./%s%d" % (constants.SUMMARY_PREFIX, i)) | |||
| summaries_metadata.pop("./%s%d" % (constants.SUMMARY_DIR_PREFIX, i)) | |||
| yield | |||
| finally: | |||
| @@ -16,7 +16,7 @@ | |||
| import tempfile | |||
| SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name | |||
| SUMMARY_PREFIX = "summary" | |||
| SUMMARY_DIR_PREFIX = "summary" | |||
| SUMMARY_DIR_NUM_FIRST = 5 | |||
| SUMMARY_DIR_NUM_SECOND = 11 | |||
| @@ -19,11 +19,11 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import os | |||
| import json | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from .. import globals as gbl | |||
| from .....utils.tools import get_url, compare_result_with_file | |||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' | |||
| @@ -33,12 +33,6 @@ class TestQueryNodes: | |||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -65,4 +59,5 @@ class TestQueryNodes: | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| self.compare_result_with_file(response.get_json(), result_file) | |||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(response.get_json(), file_path) | |||
| @@ -19,12 +19,11 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import os | |||
| import json | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from .. import globals as gbl | |||
| from .....utils.tools import get_url, compare_result_with_file | |||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' | |||
| @@ -34,12 +33,6 @@ class TestQuerySingleNode: | |||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -59,4 +52,5 @@ class TestQuerySingleNode: | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| self.compare_result_with_file(response.get_json(), result_file) | |||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(response.get_json(), file_path) | |||
| @@ -19,25 +19,20 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import os | |||
| import json | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from .. import globals as gbl | |||
| from .....utils.tools import get_url, compare_result_with_file | |||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' | |||
| class TestSearchNodes: | |||
| """Test search nodes restful APIs.""" | |||
| """Test searching nodes restful APIs.""" | |||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -58,4 +53,5 @@ class TestSearchNodes: | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| self.compare_result_with_file(response.get_json(), result_file) | |||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(response.get_json(), file_path) | |||
| @@ -20,13 +20,13 @@ Usage: | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_url | |||
| from .. import globals as gbl | |||
| from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID | |||
| BASE_URL = '/v1/mindinsight/datavisual/image/metadata' | |||
| @@ -20,11 +20,11 @@ Usage: | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_image_tensor_from_bytes, get_url | |||
| from .. import globals as gbl | |||
| BASE_URL = '/v1/mindinsight/datavisual/image/single-image' | |||
| @@ -19,11 +19,12 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_url | |||
| from .. import globals as gbl | |||
| BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata' | |||
| @@ -20,11 +20,11 @@ Usage: | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_url | |||
| from .. import globals as gbl | |||
| BASE_URL = '/v1/mindinsight/datavisual/plugins' | |||
| @@ -19,11 +19,12 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_url | |||
| from .. import globals as gbl | |||
| BASE_URL = '/v1/mindinsight/datavisual/single-job' | |||
| @@ -20,8 +20,8 @@ Usage: | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.constants import SUMMARY_DIR_NUM | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from ..constants import SUMMARY_DIR_NUM | |||
| from .....utils.tools import get_url | |||
| BASE_URL = '/v1/mindinsight/datavisual/train-jobs' | |||
| @@ -1,79 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Log generator for graph.""" | |||
| import json | |||
| import os | |||
| import time | |||
| from google.protobuf import json_format | |||
| from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| class GraphLogGenerator(LogGenerator): | |||
| """ | |||
| Log generator for graph. | |||
| This is a log generator writing graph. User can use it to generate fake | |||
| summary logs about graph. | |||
| """ | |||
| def generate_log(self, file_path, graph_dict): | |||
| """ | |||
| Generate log for external calls. | |||
| Args: | |||
| file_path (str): Path to write logs. | |||
| graph_dict (dict): A dict consists of graph node information. | |||
| Returns: | |||
| dict, generated scalar metadata. | |||
| """ | |||
| graph_event = self.generate_event(dict(graph=graph_dict)) | |||
| self._write_log_from_event(file_path, graph_event) | |||
| return graph_dict | |||
| def generate_event(self, values): | |||
| """ | |||
| Method for generating graph event. | |||
| Args: | |||
| values (dict): Graph values. e.g. {'graph': graph_dict}. | |||
| Returns: | |||
| summary_pb2.Event. | |||
| """ | |||
| graph_json = { | |||
| 'wall_time': time.time(), | |||
| 'graph_def': values.get('graph'), | |||
| } | |||
| graph_event = json_format.Parse(json.dumps(graph_json), summary_pb2.Event()) | |||
| return graph_event | |||
| if __name__ == "__main__": | |||
| graph_log_generator = GraphLogGenerator() | |||
| test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time())) | |||
| graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json") | |||
| with open(graph_base_path, 'r') as load_f: | |||
| graph = json.load(load_f) | |||
| graph_log_generator.generate_log(test_file_name, graph) | |||
| @@ -20,11 +20,11 @@ Usage: | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_url | |||
| from .. import globals as gbl | |||
| TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' | |||
| PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' | |||
| METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' | |||
| @@ -20,11 +20,11 @@ Usage: | |||
| """ | |||
| import pytest | |||
| from tests.st.func.datavisual.utils import globals as gbl | |||
| from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_image_tensor_from_bytes, get_url | |||
| from .. import globals as gbl | |||
| TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' | |||
| PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' | |||
| METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' | |||
| @@ -26,11 +26,101 @@ from unittest import TestCase | |||
| import pytest | |||
| from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||
| LineageParamSummaryPathError, LineageParamValueError, LineageParamTypeError, \ | |||
| LineageSearchConditionParamError, LineageFileNotFoundError | |||
| from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH | |||
| from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError, | |||
| LineageParamTypeError, LineageParamValueError, | |||
| LineageSearchConditionParamError) | |||
| from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 | |||
| LINEAGE_INFO_RUN1 = { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'metric': { | |||
| 'accuracy': 0.78 | |||
| }, | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'epoch': 14, | |||
| 'parallel_mode': 'stand_alone', | |||
| 'device_num': 2, | |||
| 'batch_size': 32 | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'ResNet' | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_size': 731 | |||
| }, | |||
| 'valid_dataset': { | |||
| 'valid_dataset_size': 10240 | |||
| }, | |||
| 'model': { | |||
| 'path': '{"ckpt": "' | |||
| + BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}', | |||
| 'size': 64 | |||
| }, | |||
| 'dataset_graph': DATASET_GRAPH | |||
| } | |||
| LINEAGE_FILTRATION_EXCEPT_RUN = { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 1024, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': None, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 10, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 64, | |||
| 'metric': {}, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| } | |||
| LINEAGE_FILTRATION_RUN1 = { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 731, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': None, | |||
| 'model_size': 64, | |||
| 'metric': { | |||
| 'accuracy': 0.78 | |||
| }, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| } | |||
| LINEAGE_FILTRATION_RUN2 = { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||
| 'loss_function': None, | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': None, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': None, | |||
| 'optimizer': None, | |||
| 'learning_rate': None, | |||
| 'epoch': None, | |||
| 'batch_size': None, | |||
| 'loss': None, | |||
| 'model_size': None, | |||
| 'metric': { | |||
| 'accuracy': 2.7800000000000002 | |||
| }, | |||
| 'dataset_graph': {}, | |||
| 'dataset_mark': 3 | |||
| } | |||
| @pytest.mark.usefixtures("create_summary_dir") | |||
| @@ -67,36 +157,7 @@ class TestModelApi(TestCase): | |||
| total_res = get_summary_lineage(SUMMARY_DIR) | |||
| partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters']) | |||
| partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm']) | |||
| expect_total_res = { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'metric': { | |||
| 'accuracy': 0.78 | |||
| }, | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'epoch': 14, | |||
| 'parallel_mode': 'stand_alone', | |||
| 'device_num': 2, | |||
| 'batch_size': 32 | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'ResNet' | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_size': 731 | |||
| }, | |||
| 'valid_dataset': { | |||
| 'valid_dataset_size': 10240 | |||
| }, | |||
| 'model': { | |||
| 'path': '{"ckpt": "' | |||
| + BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}', | |||
| 'size': 64 | |||
| }, | |||
| 'dataset_graph': DATASET_GRAPH | |||
| } | |||
| expect_total_res = LINEAGE_INFO_RUN1 | |||
| expect_partial_res1 = { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'hyper_parameters': { | |||
| @@ -139,7 +200,7 @@ class TestModelApi(TestCase): | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_get_summary_lineage_exception(self): | |||
| def test_get_summary_lineage_exception_1(self): | |||
| """Test the interface of get_summary_lineage with exception.""" | |||
| # summary path does not exist | |||
| self.assertRaisesRegex( | |||
| @@ -183,6 +244,14 @@ class TestModelApi(TestCase): | |||
| keys=None | |||
| ) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_get_summary_lineage_exception_2(self): | |||
| """Test the interface of get_summary_lineage with exception.""" | |||
| # keys is invalid | |||
| self.assertRaisesRegex( | |||
| LineageParamValueError, | |||
| @@ -250,64 +319,9 @@ class TestModelApi(TestCase): | |||
| """Test the interface of filter_summary_lineage.""" | |||
| expect_result = { | |||
| 'object': [ | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 1024, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': None, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 10, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 64, | |||
| 'metric': {}, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| }, | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 731, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': None, | |||
| 'model_size': 64, | |||
| 'metric': { | |||
| 'accuracy': 0.78 | |||
| }, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| }, | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||
| 'loss_function': None, | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': None, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': None, | |||
| 'optimizer': None, | |||
| 'learning_rate': None, | |||
| 'epoch': None, | |||
| 'batch_size': None, | |||
| 'loss': None, | |||
| 'model_size': None, | |||
| 'metric': { | |||
| 'accuracy': 2.7800000000000002 | |||
| }, | |||
| 'dataset_graph': {}, | |||
| 'dataset_mark': 3 | |||
| } | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1, | |||
| LINEAGE_FILTRATION_RUN2 | |||
| ], | |||
| 'count': 3 | |||
| } | |||
| @@ -357,46 +371,8 @@ class TestModelApi(TestCase): | |||
| } | |||
| expect_result = { | |||
| 'object': [ | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||
| 'loss_function': None, | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': None, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': None, | |||
| 'optimizer': None, | |||
| 'learning_rate': None, | |||
| 'epoch': None, | |||
| 'batch_size': None, | |||
| 'loss': None, | |||
| 'model_size': None, | |||
| 'metric': { | |||
| 'accuracy': 2.7800000000000002 | |||
| }, | |||
| 'dataset_graph': {}, | |||
| 'dataset_mark': 3 | |||
| }, | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 731, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': None, | |||
| 'model_size': 64, | |||
| 'metric': { | |||
| 'accuracy': 0.78 | |||
| }, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| } | |||
| LINEAGE_FILTRATION_RUN2, | |||
| LINEAGE_FILTRATION_RUN1 | |||
| ], | |||
| 'count': 2 | |||
| } | |||
| @@ -432,46 +408,8 @@ class TestModelApi(TestCase): | |||
| } | |||
| expect_result = { | |||
| 'object': [ | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||
| 'loss_function': None, | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': None, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': None, | |||
| 'optimizer': None, | |||
| 'learning_rate': None, | |||
| 'epoch': None, | |||
| 'batch_size': None, | |||
| 'loss': None, | |||
| 'model_size': None, | |||
| 'metric': { | |||
| 'accuracy': 2.7800000000000002 | |||
| }, | |||
| 'dataset_graph': {}, | |||
| 'dataset_mark': 3 | |||
| }, | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 731, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': None, | |||
| 'model_size': 64, | |||
| 'metric': { | |||
| 'accuracy': 0.78 | |||
| }, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| } | |||
| LINEAGE_FILTRATION_RUN2, | |||
| LINEAGE_FILTRATION_RUN1 | |||
| ], | |||
| 'count': 2 | |||
| } | |||
| @@ -498,44 +436,8 @@ class TestModelApi(TestCase): | |||
| } | |||
| expect_result = { | |||
| 'object': [ | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 1024, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': None, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 10, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 64, | |||
| 'metric': {}, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| }, | |||
| { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 731, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 10240, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': None, | |||
| 'model_size': 64, | |||
| 'metric': { | |||
| 'accuracy': 0.78 | |||
| }, | |||
| 'dataset_graph': DATASET_GRAPH, | |||
| 'dataset_mark': 2 | |||
| } | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1 | |||
| ], | |||
| 'count': 2 | |||
| } | |||
| @@ -674,6 +576,14 @@ class TestModelApi(TestCase): | |||
| search_condition | |||
| ) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_filter_summary_lineage_exception_3(self): | |||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | |||
| # the condition of offset is invalid | |||
| search_condition = { | |||
| 'offset': 1.0 | |||
| @@ -712,6 +622,14 @@ class TestModelApi(TestCase): | |||
| search_condition | |||
| ) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_filter_summary_lineage_exception_4(self): | |||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | |||
| # the sorted_type not supported | |||
| search_condition = { | |||
| 'sorted_name': 'summary_dir', | |||
| @@ -753,6 +671,14 @@ class TestModelApi(TestCase): | |||
| search_condition | |||
| ) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_filter_summary_lineage_exception_5(self): | |||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | |||
| # the summary dir is invalid in search condition | |||
| search_condition = { | |||
| 'summary_dir': { | |||
| @@ -811,7 +737,7 @@ class TestModelApi(TestCase): | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_filter_summary_lineage_exception_3(self): | |||
| def test_filter_summary_lineage_exception_6(self): | |||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | |||
| # gt > lt | |||
| search_condition1 = { | |||
| @@ -21,7 +21,8 @@ import tempfile | |||
| import pytest | |||
| from .collection.model import mindspore | |||
| from ....utils import mindspore | |||
| from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE | |||
| sys.modules['mindspore'] = mindspore | |||
| @@ -32,52 +33,7 @@ SUMMARY_DIR_3 = os.path.join(BASE_SUMMARY_DIR, 'except_run') | |||
| COLLECTION_MODULE = 'TestModelLineage' | |||
| API_MODULE = 'TestModelApi' | |||
| DATASET_GRAPH = { | |||
| 'op_type': 'BatchDataset', | |||
| 'op_module': 'minddata.dataengine.datasets', | |||
| 'num_parallel_workers': None, | |||
| 'drop_remainder': True, | |||
| 'batch_size': 10, | |||
| 'children': [ | |||
| { | |||
| 'op_type': 'MapDataset', | |||
| 'op_module': 'minddata.dataengine.datasets', | |||
| 'num_parallel_workers': None, | |||
| 'input_columns': [ | |||
| 'label' | |||
| ], | |||
| 'output_columns': [ | |||
| None | |||
| ], | |||
| 'operations': [ | |||
| { | |||
| 'tensor_op_module': 'minddata.transforms.c_transforms', | |||
| 'tensor_op_name': 'OneHot', | |||
| 'num_classes': 10 | |||
| } | |||
| ], | |||
| 'children': [ | |||
| { | |||
| 'op_type': 'MnistDataset', | |||
| 'shard_id': None, | |||
| 'num_shards': None, | |||
| 'op_module': 'minddata.dataengine.datasets', | |||
| 'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData', | |||
| 'num_parallel_workers': None, | |||
| 'shuffle': None, | |||
| 'num_samples': 100, | |||
| 'sampler': { | |||
| 'sampler_module': 'minddata.dataengine.samplers', | |||
| 'sampler_name': 'RandomSampler', | |||
| 'replacement': True, | |||
| 'num_samples': 100 | |||
| }, | |||
| 'children': [] | |||
| } | |||
| ] | |||
| } | |||
| ] | |||
| } | |||
| DATASET_GRAPH = SERIALIZED_PIPELINE | |||
| def get_module_name(nodeid): | |||
| """Get the module name from nodeid.""" | |||
| @@ -14,6 +14,6 @@ | |||
| # ============================================================================ | |||
| """Import the mocked mindspore.""" | |||
| import sys | |||
| from .lineagemgr.collection.model import mindspore | |||
| from ..utils import mindspore | |||
| sys.modules['mindspore'] = mindspore | |||
| @@ -21,14 +21,15 @@ Usage: | |||
| from unittest.mock import patch | |||
| import pytest | |||
| from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES | |||
| from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator | |||
| from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||
| from tests.ut.datavisual.utils.utils import get_url | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager | |||
| from ....utils.log_generators.images_log_generator import ImagesLogGenerator | |||
| from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||
| from ....utils.tools import get_url | |||
| from .conftest import TRAIN_ROUTES | |||
| class TestTrainTask: | |||
| """Test train task api.""" | |||
| @@ -36,9 +37,7 @@ class TestTrainTask: | |||
| _scalar_log_generator = ScalarsLogGenerator() | |||
| _image_log_generator = ImagesLogGenerator() | |||
| @pytest.mark.parametrize( | |||
| "plugin_name", | |||
| ['no_plugin_name', 'not_exist_plugin_name']) | |||
| @pytest.mark.parametrize("plugin_name", ['no_plugin_name', 'not_exist_plugin_name']) | |||
| def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name): | |||
| """ | |||
| Parsing unavailable plugin name to single train task. | |||
| @@ -21,14 +21,15 @@ Usage: | |||
| from unittest.mock import Mock, patch | |||
| import pytest | |||
| from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES | |||
| from tests.ut.datavisual.utils.utils import get_url | |||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.datavisual.processors.graph_processor import GraphProcessor | |||
| from mindinsight.datavisual.processors.images_processor import ImageProcessor | |||
| from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | |||
| from ....utils.tools import get_url | |||
| from .conftest import TRAIN_ROUTES | |||
| class TestTrainVisual: | |||
| """Test Train Visual APIs.""" | |||
| @@ -95,14 +96,7 @@ class TestTrainVisual: | |||
| assert response.status_code == 200 | |||
| response = response.get_json() | |||
| expected_response = { | |||
| "metadatas": [{ | |||
| "height": 224, | |||
| "step": 1, | |||
| "wall_time": 1572058058.1175, | |||
| "width": 448 | |||
| }] | |||
| } | |||
| expected_response = {"metadatas": [{"height": 224, "step": 1, "wall_time": 1572058058.1175, "width": 448}]} | |||
| assert expected_response == response | |||
| def test_single_image_with_params_miss(self, client): | |||
| @@ -254,8 +248,10 @@ class TestTrainVisual: | |||
| @patch.object(GraphProcessor, 'get_nodes') | |||
| def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client): | |||
| """Test getting graph nodes successfully.""" | |||
| def mock_get_nodes(name, node_type): | |||
| return dict(name=name, node_type=node_type) | |||
| mock_graph_processor.side_effect = mock_get_nodes | |||
| mock_init = Mock(return_value=None) | |||
| @@ -327,10 +323,7 @@ class TestTrainVisual: | |||
| assert results['error_msg'] == "Invalid parameter value. 'offset' should " \ | |||
| "be greater than or equal to 0." | |||
| @pytest.mark.parametrize( | |||
| "limit", | |||
| [-1, 0, 1001] | |||
| ) | |||
| @pytest.mark.parametrize("limit", [-1, 0, 1001]) | |||
| @patch.object(GraphProcessor, '__init__') | |||
| def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit): | |||
| """Test getting graph node names with invalid limit.""" | |||
| @@ -348,14 +341,10 @@ class TestTrainVisual: | |||
| assert results['error_msg'] == "Invalid parameter value. " \ | |||
| "'limit' should in [1, 1000]." | |||
| @pytest.mark.parametrize( | |||
| " offset, limit", | |||
| [(0, 100), (1, 1), (0, 1000)] | |||
| ) | |||
| @pytest.mark.parametrize(" offset, limit", [(0, 100), (1, 1), (0, 1000)]) | |||
| @patch.object(GraphProcessor, '__init__') | |||
| @patch.object(GraphProcessor, 'search_node_names') | |||
| def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, | |||
| offset, limit): | |||
| def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, offset, limit): | |||
| """ | |||
| Parsing unavailable params to get image metadata. | |||
| @@ -367,8 +356,10 @@ class TestTrainVisual: | |||
| response status code: 200. | |||
| response json: dict, contains search_content, offset, and limit. | |||
| """ | |||
| def mock_search_node_names(search_content, offset, limit): | |||
| return dict(search_content=search_content, offset=int(offset), limit=int(limit)) | |||
| mock_graph_processor.side_effect = mock_search_node_names | |||
| mock_init = Mock(return_value=None) | |||
| @@ -376,15 +367,12 @@ class TestTrainVisual: | |||
| test_train_id = "aaa" | |||
| test_search_content = "bbb" | |||
| params = dict(train_id=test_train_id, search=test_search_content, | |||
| offset=offset, limit=limit) | |||
| params = dict(train_id=test_train_id, search=test_search_content, offset=offset, limit=limit) | |||
| url = get_url(TRAIN_ROUTES['graph_nodes_names'], params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| results = response.get_json() | |||
| assert results == dict(search_content=test_search_content, | |||
| offset=int(offset), | |||
| limit=int(limit)) | |||
| assert results == dict(search_content=test_search_content, offset=int(offset), limit=int(limit)) | |||
| def test_graph_search_single_node_with_params_is_wrong(self, client): | |||
| """Test searching graph single node with params is wrong.""" | |||
| @@ -427,8 +415,10 @@ class TestTrainVisual: | |||
| response status code: 200. | |||
| response json: name. | |||
| """ | |||
| def mock_search_single_node(name): | |||
| return name | |||
| mock_graph_processor.side_effect = mock_search_single_node | |||
| mock_init = Mock(return_value=None) | |||
| @@ -20,8 +20,42 @@ from unittest import TestCase, mock | |||
| from flask import Response | |||
| from mindinsight.backend.application import APP | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||
| LineageQuerySummaryDataError | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError | |||
| LINEAGE_FILTRATION_BASE = { | |||
| 'accuracy': None, | |||
| 'mae': None, | |||
| 'mse': None, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 64, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': None, | |||
| 'network': 'str', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 12, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 128 | |||
| } | |||
| LINEAGE_FILTRATION_RUN1 = { | |||
| 'accuracy': 0.78, | |||
| 'mae': None, | |||
| 'mse': None, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 64, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 64, | |||
| 'network': 'str', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 128 | |||
| } | |||
| class TestSearchModel(TestCase): | |||
| @@ -42,39 +76,11 @@ class TestSearchModel(TestCase): | |||
| 'object': [ | |||
| { | |||
| 'summary_dir': base_dir, | |||
| 'accuracy': None, | |||
| 'mae': None, | |||
| 'mse': None, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 64, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': None, | |||
| 'network': 'str', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 12, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 128 | |||
| **LINEAGE_FILTRATION_BASE | |||
| }, | |||
| { | |||
| 'summary_dir': os.path.join(base_dir, 'run1'), | |||
| 'accuracy': 0.78, | |||
| 'mae': None, | |||
| 'mse': None, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 64, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 64, | |||
| 'network': 'str', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 128 | |||
| **LINEAGE_FILTRATION_RUN1 | |||
| } | |||
| ], | |||
| 'count': 2 | |||
| @@ -93,39 +99,11 @@ class TestSearchModel(TestCase): | |||
| 'object': [ | |||
| { | |||
| 'summary_dir': './', | |||
| 'accuracy': None, | |||
| 'mae': None, | |||
| 'mse': None, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 64, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': None, | |||
| 'network': 'str', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 12, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 128 | |||
| **LINEAGE_FILTRATION_BASE | |||
| }, | |||
| { | |||
| 'summary_dir': './run1', | |||
| 'accuracy': 0.78, | |||
| 'mae': None, | |||
| 'mse': None, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'train_dataset_path': None, | |||
| 'train_dataset_count': 64, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 64, | |||
| 'network': 'str', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'loss': 0.029999999329447746, | |||
| 'model_size': 128 | |||
| **LINEAGE_FILTRATION_RUN1 | |||
| } | |||
| ], | |||
| 'count': 2 | |||
| @@ -19,15 +19,16 @@ Usage: | |||
| pytest tests/ut/datavisual | |||
| """ | |||
| from unittest.mock import patch | |||
| from werkzeug.exceptions import MethodNotAllowed, NotFound | |||
| from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from tests.ut.datavisual.utils.utils import get_url | |||
| from werkzeug.exceptions import MethodNotAllowed, NotFound | |||
| from mindinsight.datavisual.processors import scalars_processor | |||
| from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | |||
| from ....utils.tools import get_url | |||
| from ...backend.datavisual.conftest import TRAIN_ROUTES | |||
| from ..mock import MockLogger | |||
| class TestErrorHandler: | |||
| """Test train visual api.""" | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test mindinsight.datavisual.data_transform.log_generators.data_loader_generator | |||
| Test mindinsight.datavisual.data_transform.loader_generators.data_loader_generator | |||
| Usage: | |||
| pytest tests/ut/datavisual | |||
| """ | |||
| @@ -22,18 +22,19 @@ import datetime | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| from unittest.mock import patch | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| import pytest | |||
| from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from ...mock import MockLogger | |||
| class TestDataLoaderGenerator: | |||
| """Test data_loader_generator.""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| data_loader_generator.logger = MockLogger | |||
| @@ -88,8 +89,9 @@ class TestDataLoaderGenerator: | |||
| mock_data_loader.return_value = True | |||
| loader_dict = generator.generate_loaders(loader_pool=dict()) | |||
| expected_ids = [summary.get('relative_path') | |||
| for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]] | |||
| expected_ids = [ | |||
| summary.get('relative_path') for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:] | |||
| ] | |||
| assert sorted(loader_dict.keys()) == sorted(expected_ids) | |||
| shutil.rmtree(summary_base_dir) | |||
| @@ -23,12 +23,13 @@ import shutil | |||
| import tempfile | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid | |||
| from mindinsight.datavisual.data_transform import data_loader | |||
| from mindinsight.datavisual.data_transform.data_loader import DataLoader | |||
| from ..mock import MockLogger | |||
| class TestDataLoader: | |||
| """Test data_loader.""" | |||
| @@ -37,13 +38,13 @@ class TestDataLoader: | |||
| def setup_class(cls): | |||
| data_loader.logger = MockLogger | |||
| def setup_method(self, method): | |||
| def setup_method(self): | |||
| self._summary_dir = tempfile.mkdtemp() | |||
| if os.path.exists(self._summary_dir): | |||
| shutil.rmtree(self._summary_dir) | |||
| os.mkdir(self._summary_dir) | |||
| def teardown_method(self, method): | |||
| def teardown_method(self): | |||
| if os.path.exists(self._summary_dir): | |||
| shutil.rmtree(self._summary_dir) | |||
| @@ -18,32 +18,29 @@ Function: | |||
| Usage: | |||
| pytest tests/ut/datavisual | |||
| """ | |||
| import time | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import time | |||
| from unittest import mock | |||
| from unittest.mock import Mock | |||
| from unittest.mock import patch | |||
| from unittest.mock import Mock, patch | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from tests.ut.datavisual.utils.utils import check_loading_done | |||
| from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum | |||
| from mindinsight.datavisual.data_transform import data_manager, ms_data_loader | |||
| from mindinsight.datavisual.data_transform.data_loader import DataLoader | |||
| from mindinsight.datavisual.data_transform.data_manager import DataManager | |||
| from mindinsight.datavisual.data_transform.events_data import EventsData | |||
| from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import \ | |||
| DataLoaderGenerator | |||
| from mindinsight.datavisual.data_transform.loader_generators.loader_generator import \ | |||
| MAX_DATA_LOADER_SIZE | |||
| from mindinsight.datavisual.data_transform.loader_generators.loader_struct import \ | |||
| LoaderStruct | |||
| from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator | |||
| from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE | |||
| from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct | |||
| from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from ....utils.tools import check_loading_done | |||
| from ..mock import MockLogger | |||
| class TestDataManager: | |||
| """Test data_manager.""" | |||
| @@ -101,11 +98,17 @@ class TestDataManager: | |||
| "and loader pool size is '3'." | |||
| shutil.rmtree(summary_base_dir) | |||
| @pytest.mark.parametrize('params', | |||
| [{'reload_interval': '30'}, | |||
| {'reload_interval': -1}, | |||
| {'reload_interval': 30, 'max_threads_count': '20'}, | |||
| {'reload_interval': 30, 'max_threads_count': 0}]) | |||
| @pytest.mark.parametrize('params', [{ | |||
| 'reload_interval': '30' | |||
| }, { | |||
| 'reload_interval': -1 | |||
| }, { | |||
| 'reload_interval': 30, | |||
| 'max_threads_count': '20' | |||
| }, { | |||
| 'reload_interval': 30, | |||
| 'max_threads_count': 0 | |||
| }]) | |||
| def test_start_load_data_with_invalid_params(self, params): | |||
| """Test start_load_data with invalid reload_interval or invalid max_threads_count.""" | |||
| summary_base_dir = tempfile.mkdtemp() | |||
| @@ -22,20 +22,24 @@ import threading | |||
| from collections import namedtuple | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.data_transform import events_data | |||
| from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor | |||
| from ..mock import MockLogger | |||
| class MockReservoir: | |||
| """Use this class to replace reservoir.Reservoir in test.""" | |||
| def __init__(self, size): | |||
| self.size = size | |||
| self._samples = [_Tensor('wall_time1', 1, 'value1'), _Tensor('wall_time2', 2, 'value2'), | |||
| _Tensor('wall_time3', 3, 'value3')] | |||
| self._samples = [ | |||
| _Tensor('wall_time1', 1, 'value1'), | |||
| _Tensor('wall_time2', 2, 'value2'), | |||
| _Tensor('wall_time3', 3, 'value3') | |||
| ] | |||
| def samples(self): | |||
| """Replace the samples function.""" | |||
| @@ -63,11 +67,12 @@ class TestEventsData: | |||
| def setup_method(self): | |||
| """Mock original logger, init a EventsData object for use.""" | |||
| self._ev_data = EventsData() | |||
| self._ev_data._tags_by_plugin = {'plugin_name1': [f'tag{i}' for i in range(10)], | |||
| 'plugin_name2': [f'tag{i}' for i in range(20, 30)]} | |||
| self._ev_data._tags_by_plugin = { | |||
| 'plugin_name1': [f'tag{i}' for i in range(10)], | |||
| 'plugin_name2': [f'tag{i}' for i in range(20, 30)] | |||
| } | |||
| self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()}) | |||
| self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), | |||
| 'new_tag': MockReservoir(500)} | |||
| self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), 'new_tag': MockReservoir(500)} | |||
| self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)] | |||
| def get_ev_data(self): | |||
| @@ -102,8 +107,7 @@ class TestEventsData: | |||
| """Test add_tensor_event success.""" | |||
| ev_data = self.get_ev_data() | |||
| t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', | |||
| value='value1') | |||
| t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', value='value1') | |||
| ev_data.add_tensor_event(t_event) | |||
| assert 'tag0' not in ev_data._tags | |||
| @@ -111,6 +115,5 @@ class TestEventsData: | |||
| assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1'] | |||
| assert 'tag0' not in ev_data._reservoir_by_tag | |||
| assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1'] | |||
| assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, | |||
| t_event.step, | |||
| assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, t_event.step, | |||
| t_event.value) | |||
| @@ -19,16 +19,17 @@ Usage: | |||
| pytest tests/ut/datavisual | |||
| """ | |||
| import os | |||
| import tempfile | |||
| import shutil | |||
| import tempfile | |||
| from unittest.mock import Mock | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from mindinsight.datavisual.data_transform import ms_data_loader | |||
| from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | |||
| from ..mock import MockLogger | |||
| # bytes of 3 scalar events | |||
| SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' | |||
| b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00' | |||
| @@ -74,7 +75,8 @@ class TestMsDataLoader: | |||
| "we will reload all files in path {}.".format(summary_dir) | |||
| shutil.rmtree(summary_dir) | |||
| def test_load_success_with_crc_pass(self, crc_pass): | |||
| @pytest.mark.usefixtures('crc_pass') | |||
| def test_load_success_with_crc_pass(self): | |||
| """Test load success.""" | |||
| summary_dir = tempfile.mkdtemp() | |||
| file1 = os.path.join(summary_dir, 'summary.01') | |||
| @@ -88,7 +90,8 @@ class TestMsDataLoader: | |||
| tensors = ms_loader.get_events_data().tensors(tag[0]) | |||
| assert len(tensors) == 3 | |||
| def test_load_with_crc_fail(self, crc_fail): | |||
| @pytest.mark.usefixtures('crc_fail') | |||
| def test_load_with_crc_fail(self): | |||
| """Test when crc_fail and will not go to func _event_parse.""" | |||
| summary_dir = tempfile.mkdtemp() | |||
| file2 = os.path.join(summary_dir, 'summary.02') | |||
| @@ -100,8 +103,10 @@ class TestMsDataLoader: | |||
| def test_filter_event_files(self): | |||
| """Test filter_event_files function ok.""" | |||
| file_list = ['abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', | |||
| 'summary.0012', 'hellosummary.98786', 'mysummary.123abce', 'summay.4567'] | |||
| file_list = [ | |||
| 'abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', 'summary.0012', 'hellosummary.98786', | |||
| 'mysummary.123abce', 'summay.4567' | |||
| ] | |||
| summary_dir = tempfile.mkdtemp() | |||
| for file in file_list: | |||
| with open(os.path.join(summary_dir, file), 'w'): | |||
| @@ -113,6 +118,7 @@ class TestMsDataLoader: | |||
| shutil.rmtree(summary_dir) | |||
| def write_file(filename, record): | |||
| """Write bytes strings to file.""" | |||
| with open(filename, 'wb') as file: | |||
| @@ -19,18 +19,11 @@ Usage: | |||
| pytest tests/ut/datavisual | |||
| """ | |||
| import os | |||
| import json | |||
| import tempfile | |||
| from unittest.mock import Mock | |||
| from unittest.mock import patch | |||
| from unittest.mock import Mock, patch | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| @@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor | |||
| from mindinsight.datavisual.utils import crc32 | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from ....utils.log_operations import LogOperations | |||
| from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs | |||
| from ..mock import MockLogger | |||
| class TestGraphProcessor: | |||
| """Test Graph Processor api.""" | |||
| @@ -70,18 +67,13 @@ class TestGraphProcessor: | |||
| """Load graph record.""" | |||
| summary_base_dir = tempfile.mkdtemp() | |||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | |||
| self._train_id = log_dir.replace(summary_base_dir, ".") | |||
| graph_base_path = os.path.join(os.path.dirname(__file__), | |||
| os.pardir, "utils", "log_generators", "graph_base.json") | |||
| self._temp_path, self._graph_dict = LogOperations.generate_log( | |||
| PluginNameEnum.GRAPH.value, log_dir, dict(graph_base_path=graph_base_path)) | |||
| log_operation = LogOperations() | |||
| self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir) | |||
| self._generated_path.append(summary_base_dir) | |||
| self._mock_data_manager = data_manager.DataManager( | |||
| [DataLoaderGenerator(summary_base_dir)]) | |||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | |||
| self._mock_data_manager.start_load_data(reload_interval=0) | |||
| # wait for loading done | |||
| @@ -94,33 +86,29 @@ class TestGraphProcessor: | |||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | |||
| self._train_id = log_dir.replace(summary_base_dir, ".") | |||
| self._temp_path, _, _ = LogOperations.generate_log( | |||
| PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image")) | |||
| log_operation = LogOperations() | |||
| self._temp_path, _, _ = log_operation.generate_log(PluginNameEnum.IMAGE.value, log_dir, | |||
| dict(steps=self._steps_list, tag="image")) | |||
| self._generated_path.append(summary_base_dir) | |||
| self._mock_data_manager = data_manager.DataManager( | |||
| [DataLoaderGenerator(summary_base_dir)]) | |||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | |||
| self._mock_data_manager.start_load_data(reload_interval=0) | |||
| # wait for loading done | |||
| check_loading_done(self._mock_data_manager, time_limit=5) | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| def test_get_nodes_with_not_exist_train_id(self, load_graph_record): | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| def test_get_nodes_with_not_exist_train_id(self): | |||
| """Test getting nodes with not exist train id.""" | |||
| test_train_id = "not_exist_train_id" | |||
| with pytest.raises(ParamValueError) as exc_info: | |||
| GraphProcessor(test_train_id, self._mock_data_manager) | |||
| assert "Can not find the train job in data manager." in exc_info.value.message | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| @patch.object(DataManager, 'get_train_job_by_plugin') | |||
| def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin, load_graph_record): | |||
| def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin): | |||
| """Test get nodes with loader is None.""" | |||
| mock_get_train_job_by_plugin.return_value = None | |||
| with pytest.raises(exceptions.SummaryLogPathInvalid): | |||
| @@ -128,15 +116,12 @@ class TestGraphProcessor: | |||
| assert mock_get_train_job_by_plugin.called | |||
| @pytest.mark.parametrize("name, node_type", [ | |||
| ("not_exist_name", "name_scope"), | |||
| ("", "polymeric_scope") | |||
| ]) | |||
| def test_get_nodes_with_not_exist_name(self, load_graph_record, name, node_type): | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| @pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")]) | |||
| def test_get_nodes_with_not_exist_name(self, name, node_type): | |||
| """Test getting nodes with not exist name.""" | |||
| with pytest.raises(ParamValueError) as exc_info: | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| graph_processor.get_nodes(name, node_type) | |||
| if name: | |||
| @@ -144,105 +129,99 @@ class TestGraphProcessor: | |||
| else: | |||
| assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message | |||
| @pytest.mark.parametrize("name, node_type, result_file", [ | |||
| (None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), | |||
| ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), | |||
| ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json') | |||
| ]) | |||
| def test_get_nodes_success(self, load_graph_record, name, node_type, result_file): | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| @pytest.mark.parametrize( | |||
| "name, node_type, result_file", | |||
| [(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), | |||
| ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), | |||
| ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')]) | |||
| def test_get_nodes_success(self, name, node_type, result_file): | |||
| """Test getting nodes successfully.""" | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| results = graph_processor.get_nodes(name, node_type) | |||
| self.compare_result_with_file(results, result_file) | |||
| @pytest.mark.parametrize("search_content, result_file", [ | |||
| (None, 'test_search_node_names_with_search_content_expected_results1.json'), | |||
| ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), | |||
| ('not_exist_search_content', None) | |||
| ]) | |||
| def test_search_node_names_with_search_content(self, load_graph_record, | |||
| search_content, | |||
| result_file): | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| @pytest.mark.parametrize("search_content, result_file", | |||
| [(None, 'test_search_node_names_with_search_content_expected_results1.json'), | |||
| ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), | |||
| ('not_exist_search_content', None)]) | |||
| def test_search_node_names_with_search_content(self, search_content, result_file): | |||
| """Test search node names with search content.""" | |||
| test_offset = 0 | |||
| test_limit = 1000 | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| results = graph_processor.search_node_names(search_content, | |||
| test_offset, | |||
| test_limit) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| results = graph_processor.search_node_names(search_content, test_offset, test_limit) | |||
| if search_content == 'not_exist_search_content': | |||
| expected_results = {'names': []} | |||
| assert results == expected_results | |||
| else: | |||
| self.compare_result_with_file(results, result_file) | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| @pytest.mark.parametrize("offset", [-100, -1]) | |||
| def test_search_node_names_with_negative_offset(self, load_graph_record, offset): | |||
| def test_search_node_names_with_negative_offset(self, offset): | |||
| """Test search node names with negative offset.""" | |||
| test_search_content = "" | |||
| test_limit = 3 | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| with pytest.raises(ParamValueError) as exc_info: | |||
| graph_processor.search_node_names(test_search_content, offset, test_limit) | |||
| assert "'offset' should be greater than or equal to 0." in exc_info.value.message | |||
| @pytest.mark.parametrize("offset, result_file", [ | |||
| (1, 'test_search_node_names_with_offset_expected_results1.json') | |||
| ]) | |||
| def test_search_node_names_with_offset(self, load_graph_record, offset, result_file): | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| @pytest.mark.parametrize("offset, result_file", [(1, 'test_search_node_names_with_offset_expected_results1.json')]) | |||
| def test_search_node_names_with_offset(self, offset, result_file): | |||
| """Test search node names with offset.""" | |||
| test_search_content = "Default/bn1" | |||
| test_offset = offset | |||
| test_limit = 3 | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| results = graph_processor.search_node_names(test_search_content, | |||
| test_offset, | |||
| test_limit) | |||
| self.compare_result_with_file(results, result_file) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| results = graph_processor.search_node_names(test_search_content, test_offset, test_limit) | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| def test_search_node_names_with_wrong_limit(self, load_graph_record): | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| def test_search_node_names_with_wrong_limit(self): | |||
| """Test search node names with wrong limit.""" | |||
| test_search_content = "" | |||
| test_offset = 0 | |||
| test_limit = 0 | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| with pytest.raises(ParamValueError) as exc_info: | |||
| graph_processor.search_node_names(test_search_content, test_offset, | |||
| test_limit) | |||
| graph_processor.search_node_names(test_search_content, test_offset, test_limit) | |||
| assert "'limit' should in [1, 1000]." in exc_info.value.message | |||
| @pytest.mark.parametrize("name, result_file", [ | |||
| ('Default/bn1', 'test_search_single_node_success_expected_results1.json') | |||
| ]) | |||
| def test_search_single_node_success(self, load_graph_record, name, result_file): | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| @pytest.mark.parametrize("name, result_file", | |||
| [('Default/bn1', 'test_search_single_node_success_expected_results1.json')]) | |||
| def test_search_single_node_success(self, name, result_file): | |||
| """Test searching single node successfully.""" | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| results = graph_processor.search_single_node(name) | |||
| self.compare_result_with_file(results, result_file) | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| def test_search_single_node_with_not_exist_name(self, load_graph_record): | |||
| @pytest.mark.usefixtures('load_graph_record') | |||
| def test_search_single_node_with_not_exist_name(self): | |||
| """Test searching single node with not exist name.""" | |||
| test_name = "not_exist_name" | |||
| with pytest.raises(exceptions.NodeNotInGraphError): | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||
| graph_processor.search_single_node(test_name) | |||
| def test_check_graph_status_no_graph(self, load_no_graph_record): | |||
| @pytest.mark.usefixtures('load_no_graph_record') | |||
| def test_check_graph_status_no_graph(self): | |||
| """Test checking graph status no graph.""" | |||
| with pytest.raises(ParamValueError) as exc_info: | |||
| GraphProcessor(self._train_id, self._mock_data_manager) | |||
| @@ -22,9 +22,6 @@ import tempfile | |||
| from unittest.mock import Mock | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| @@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor | |||
| from mindinsight.datavisual.utils import crc32 | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from ....utils.log_operations import LogOperations | |||
| from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes | |||
| from ..mock import MockLogger | |||
| class TestImagesProcessor: | |||
| """Test images processor api.""" | |||
| @@ -73,12 +74,11 @@ class TestImagesProcessor: | |||
| """ | |||
| summary_base_dir = tempfile.mkdtemp() | |||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | |||
| self._train_id = log_dir.replace(summary_base_dir, ".") | |||
| self._temp_path, self._images_metadata, self._images_values = LogOperations.generate_log( | |||
| log_operation = LogOperations() | |||
| self._temp_path, self._images_metadata, self._images_values = log_operation.generate_log( | |||
| PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name)) | |||
| self._generated_path.append(summary_base_dir) | |||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | |||
| @@ -102,7 +102,8 @@ class TestImagesProcessor: | |||
| """Load image record.""" | |||
| self._init_data_manager(self._cross_steps_list) | |||
| def test_get_metadata_list_with_not_exist_id(self, load_image_record): | |||
| @pytest.mark.usefixtures('load_image_record') | |||
| def test_get_metadata_list_with_not_exist_id(self): | |||
| """Test getting metadata list with not exist id.""" | |||
| test_train_id = 'not_exist_id' | |||
| image_processor = ImageProcessor(self._mock_data_manager) | |||
| @@ -112,7 +113,8 @@ class TestImagesProcessor: | |||
| assert exc_info.value.error_code == '50540002' | |||
| assert "Can not find any data in loader pool about the train job." in exc_info.value.message | |||
| def test_get_metadata_list_with_not_exist_tag(self, load_image_record): | |||
| @pytest.mark.usefixtures('load_image_record') | |||
| def test_get_metadata_list_with_not_exist_tag(self): | |||
| """Test get metadata list with not exist tag.""" | |||
| test_tag_name = 'not_exist_tag_name' | |||
| @@ -124,7 +126,8 @@ class TestImagesProcessor: | |||
| assert exc_info.value.error_code == '50540002' | |||
| assert "Can not find any data in this train job by given tag." in exc_info.value.message | |||
| def test_get_metadata_list_success(self, load_image_record): | |||
| @pytest.mark.usefixtures('load_image_record') | |||
| def test_get_metadata_list_success(self): | |||
| """Test getting metadata list success.""" | |||
| test_tag_name = self._complete_tag_name | |||
| @@ -133,7 +136,8 @@ class TestImagesProcessor: | |||
| assert results == self._images_metadata | |||
| def test_get_single_image_with_not_exist_id(self, load_image_record): | |||
| @pytest.mark.usefixtures('load_image_record') | |||
| def test_get_single_image_with_not_exist_id(self): | |||
| """Test getting single image with not exist id.""" | |||
| test_train_id = 'not_exist_id' | |||
| test_tag_name = self._complete_tag_name | |||
| @@ -146,7 +150,8 @@ class TestImagesProcessor: | |||
| assert exc_info.value.error_code == '50540002' | |||
| assert "Can not find any data in loader pool about the train job." in exc_info.value.message | |||
| def test_get_single_image_with_not_exist_tag(self, load_image_record): | |||
| @pytest.mark.usefixtures('load_image_record') | |||
| def test_get_single_image_with_not_exist_tag(self): | |||
| """Test getting single image with not exist tag.""" | |||
| test_tag_name = 'not_exist_tag_name' | |||
| test_step = self._steps_list[0] | |||
| @@ -159,7 +164,8 @@ class TestImagesProcessor: | |||
| assert exc_info.value.error_code == '50540002' | |||
| assert "Can not find any data in this train job by given tag." in exc_info.value.message | |||
| def test_get_single_image_with_not_exist_step(self, load_image_record): | |||
| @pytest.mark.usefixtures('load_image_record') | |||
| def test_get_single_image_with_not_exist_step(self): | |||
| """Test getting single image with not exist step.""" | |||
| test_tag_name = self._complete_tag_name | |||
| test_step = 10000 | |||
| @@ -172,24 +178,22 @@ class TestImagesProcessor: | |||
| assert exc_info.value.error_code == '50540002' | |||
| assert "Can not find the step with given train job id and tag." in exc_info.value.message | |||
| def test_get_single_image_success(self, load_image_record): | |||
| @pytest.mark.usefixtures('load_image_record') | |||
| def test_get_single_image_success(self): | |||
| """Test getting single image successfully.""" | |||
| test_tag_name = self._complete_tag_name | |||
| test_step_index = 0 | |||
| test_step = self._steps_list[test_step_index] | |||
| expected_image_tensor = self._images_values.get(test_step) | |||
| image_processor = ImageProcessor(self._mock_data_manager) | |||
| results = image_processor.get_single_image(self._train_id, test_tag_name, test_step) | |||
| expected_image_tensor = self._images_values.get(test_step) | |||
| image_generator = LogOperations.get_log_generator(PluginNameEnum.IMAGE.value) | |||
| recv_image_tensor = image_generator.get_image_tensor_from_bytes(results) | |||
| recv_image_tensor = get_image_tensor_from_bytes(results) | |||
| assert recv_image_tensor.any() == expected_image_tensor.any() | |||
| def test_reservoir_add_sample(self, load_more_than_limit_image_record): | |||
| @pytest.mark.usefixtures('load_more_than_limit_image_record') | |||
| def test_reservoir_add_sample(self): | |||
| """Test adding sample in reservoir.""" | |||
| test_tag_name = self._complete_tag_name | |||
| @@ -206,7 +210,8 @@ class TestImagesProcessor: | |||
| cnt += 1 | |||
| assert len(self._more_steps_list) - cnt == 10 | |||
| def test_reservoir_remove_sample(self, load_reservoir_remove_sample_image_record): | |||
| @pytest.mark.usefixtures('load_reservoir_remove_sample_image_record') | |||
| def test_reservoir_remove_sample(self): | |||
| """ | |||
| Test removing sample in reservoir. | |||
| @@ -22,9 +22,6 @@ import tempfile | |||
| from unittest.mock import Mock | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| @@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | |||
| from mindinsight.datavisual.utils import crc32 | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from ....utils.log_operations import LogOperations | |||
| from ....utils.tools import check_loading_done, delete_files_or_dirs | |||
| from ..mock import MockLogger | |||
| class TestScalarsProcessor: | |||
| """Test scalar processor api.""" | |||
| @@ -65,12 +66,11 @@ class TestScalarsProcessor: | |||
| """Load scalar record.""" | |||
| summary_base_dir = tempfile.mkdtemp() | |||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | |||
| self._train_id = log_dir.replace(summary_base_dir, ".") | |||
| self._temp_path, self._scalars_metadata, self._scalars_values = LogOperations.generate_log( | |||
| log_operation = LogOperations() | |||
| self._temp_path, self._scalars_metadata, self._scalars_values = log_operation.generate_log( | |||
| PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name)) | |||
| self._generated_path.append(summary_base_dir) | |||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | |||
| @@ -79,7 +79,8 @@ class TestScalarsProcessor: | |||
| # wait for loading done | |||
| check_loading_done(self._mock_data_manager, time_limit=5) | |||
| def test_get_metadata_list_with_not_exist_id(self, load_scalar_record): | |||
| @pytest.mark.usefixtures('load_scalar_record') | |||
| def test_get_metadata_list_with_not_exist_id(self): | |||
| """Get metadata list with not exist id.""" | |||
| test_train_id = 'not_exist_id' | |||
| scalar_processor = ScalarsProcessor(self._mock_data_manager) | |||
| @@ -89,7 +90,8 @@ class TestScalarsProcessor: | |||
| assert exc_info.value.error_code == '50540002' | |||
| assert "Can not find any data in loader pool about the train job." in exc_info.value.message | |||
| def test_get_metadata_list_with_not_exist_tag(self, load_scalar_record): | |||
| @pytest.mark.usefixtures('load_scalar_record') | |||
| def test_get_metadata_list_with_not_exist_tag(self): | |||
| """Get metadata list with not exist tag.""" | |||
| test_tag_name = 'not_exist_tag_name' | |||
| @@ -101,7 +103,8 @@ class TestScalarsProcessor: | |||
| assert exc_info.value.error_code == '50540002' | |||
| assert "Can not find any data in this train job by given tag." in exc_info.value.message | |||
| def test_get_metadata_list_success(self, load_scalar_record): | |||
| @pytest.mark.usefixtures('load_scalar_record') | |||
| def test_get_metadata_list_success(self): | |||
| """Get metadata list success.""" | |||
| test_tag_name = self._complete_tag_name | |||
| @@ -18,15 +18,11 @@ Function: | |||
| Usage: | |||
| pytest tests/ut/datavisual | |||
| """ | |||
| import os | |||
| import tempfile | |||
| import time | |||
| from unittest.mock import Mock | |||
| import pytest | |||
| from tests.ut.datavisual.mock import MockLogger | |||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| @@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage | |||
| from mindinsight.datavisual.utils import crc32 | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from ....utils.log_operations import LogOperations | |||
| from ....utils.tools import check_loading_done, delete_files_or_dirs | |||
| from ..mock import MockLogger | |||
| class TestTrainTaskManager: | |||
| """Test train task manager.""" | |||
| @@ -70,39 +70,30 @@ class TestTrainTaskManager: | |||
| @pytest.fixture(scope='function') | |||
| def load_data(self): | |||
| """Load data.""" | |||
| log_operation = LogOperations() | |||
| self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []} | |||
| self._events_names = [] | |||
| self._train_id_list = [] | |||
| graph_base_path = os.path.join(os.path.dirname(__file__), | |||
| os.pardir, "utils", "log_generators", "graph_base.json") | |||
| self._root_dir = tempfile.mkdtemp() | |||
| for i in range(self._dir_num): | |||
| dir_path = tempfile.mkdtemp(dir=self._root_dir) | |||
| tmp_tag_name = self._tag_name + '_' + str(i) | |||
| event_name = str(i) + "_name" | |||
| train_id = dir_path.replace(self._root_dir, ".") | |||
| # Pass timestamp to write to the same file. | |||
| log_settings = dict( | |||
| steps=self._steps_list, | |||
| tag=tmp_tag_name, | |||
| graph_base_path=graph_base_path, | |||
| time=time.time()) | |||
| log_settings = dict(steps=self._steps_list, tag=tmp_tag_name, time=time.time()) | |||
| if i % 3 != 0: | |||
| LogOperations.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) | |||
| log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) | |||
| self._plugins_id_map['image'].append(train_id) | |||
| if i % 3 != 1: | |||
| LogOperations.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings) | |||
| log_operation.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings) | |||
| self._plugins_id_map['scalar'].append(train_id) | |||
| if i % 3 != 2: | |||
| LogOperations.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings) | |||
| log_operation.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings) | |||
| self._plugins_id_map['graph'].append(train_id) | |||
| self._events_names.append(event_name) | |||
| self._train_id_list.append(train_id) | |||
| self._generated_path.append(self._root_dir) | |||
| @@ -112,7 +103,8 @@ class TestTrainTaskManager: | |||
| check_loading_done(self._mock_data_manager, time_limit=30) | |||
| def test_get_single_train_task_with_not_exists_train_id(self, load_data): | |||
| @pytest.mark.usefixtures('load_data') | |||
| def test_get_single_train_task_with_not_exists_train_id(self): | |||
| """Test getting single train task with not exists train_id.""" | |||
| train_task_manager = TrainTaskManager(self._mock_data_manager) | |||
| for plugin_name in PluginNameEnum.list_members(): | |||
| @@ -124,7 +116,8 @@ class TestTrainTaskManager: | |||
| "the train job in data manager." | |||
| assert exc_info.value.error_code == '50540002' | |||
| def test_get_single_train_task_with_params(self, load_data): | |||
| @pytest.mark.usefixtures('load_data') | |||
| def test_get_single_train_task_with_params(self): | |||
| """Test getting single train task with params.""" | |||
| train_task_manager = TrainTaskManager(self._mock_data_manager) | |||
| for plugin_name in PluginNameEnum.list_members(): | |||
| @@ -138,7 +131,8 @@ class TestTrainTaskManager: | |||
| else: | |||
| assert test_train_id not in self._plugins_id_map.get(plugin_name) | |||
| def test_get_plugins_with_train_id(self, load_data): | |||
| @pytest.mark.usefixtures('load_data') | |||
| def test_get_plugins_with_train_id(self): | |||
| """Test getting plugins with train id.""" | |||
| train_task_manager = TrainTaskManager(self._mock_data_manager) | |||
| @@ -1,14 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -1,85 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mask string to crc32.""" | |||
| CRC_TABLE_32 = ( | |||
| 0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, 0x8AD958CF, | |||
| 0x78B2DBCC, 0x6BE22838, 0x9989AB3B, 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, 0x105EC76F, 0xE235446C, | |||
| 0xF165B798, 0x030E349B, 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, | |||
| 0x89D76C54, 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A, | |||
| 0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, 0x6DFE410E, | |||
| 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, 0xF779DEAE, 0x05125DAD, | |||
| 0x1642AE59, 0xE4292D5A, 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, 0x7DA08661, 0x8FCB0562, 0x9C9BF696, | |||
| 0x6EF07595, 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957, | |||
| 0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, 0x5125DAD3, | |||
| 0xA34E59D0, 0xB01EAA24, 0x42752927, 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, 0xDBFC821C, 0x2997011F, | |||
| 0x3AC7F2EB, 0xC8AC71E8, 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, 0x61C69362, 0x93AD1061, 0x80FDE395, | |||
| 0x72966096, 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859, | |||
| 0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, 0xB602C312, | |||
| 0x44694011, 0x5739B3E5, 0xA55230E6, 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, 0x3CDB9BDD, 0xCEB018DE, | |||
| 0xDDE0EB2A, 0x2F8B6829, 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, 0x456CAC67, 0xB7072F64, 0xA457DC90, | |||
| 0x563C5F93, 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C, | |||
| 0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, 0x1871A4D8, | |||
| 0xEA1A27DB, 0xF94AD42F, 0x0B21572C, 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, 0xA24BB5A6, 0x502036A5, | |||
| 0x4370C551, 0xB11B4652, 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, | |||
| 0x3BC21E9D, 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D, | |||
| 0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, 0xFF56BD19, | |||
| 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, 0x0417B1DB, 0xF67C32D8, | |||
| 0xE52CC12C, 0x1747422F, 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, 0x8ECEE914, 0x7CA56A17, 0x6FF599E3, | |||
| 0x9D9E1AE0, 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540, | |||
| 0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, 0xE330A81A, | |||
| 0x115B2B19, 0x020BD8ED, 0xF0605BEE, 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, 0x69E9F0D5, 0x9B8273D6, | |||
| 0x88D28022, 0x7AB90321, 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, 0xF36E6F75, 0x0105EC76, 0x12551F82, | |||
| 0xE03E9C81, 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E, | |||
| 0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351 | |||
| ) | |||
| _CRC = 0 | |||
| _MASK = 0xFFFFFFFF | |||
| def _uint32(x): | |||
| """Transform x's type to uint32.""" | |||
| return x & 0xFFFFFFFF | |||
| def _get_crc_checksum(crc, data): | |||
| """Get crc checksum.""" | |||
| crc ^= _MASK | |||
| for d in data: | |||
| crc_table_index = (crc ^ d) & 0xFF | |||
| crc = (CRC_TABLE_32[crc_table_index] ^ (crc >> 8)) & _MASK | |||
| crc ^= _MASK | |||
| return crc | |||
| def get_mask_from_string(data): | |||
| """ | |||
| Get masked crc from data. | |||
| Args: | |||
| data (byte): Byte string of data. | |||
| Returns: | |||
| uint32, masked crc. | |||
| """ | |||
| crc = _get_crc_checksum(_CRC, data) | |||
| crc = _uint32(crc & _MASK) | |||
| crc = _uint32(((crc >> 15) | _uint32(crc << 17)) + 0xA282EAD8) | |||
| return crc | |||
| @@ -1,14 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -1,166 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Log generator for images.""" | |||
| import io | |||
| import time | |||
| import numpy as np | |||
| from PIL import Image | |||
| from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| class ImagesLogGenerator(LogGenerator): | |||
| """ | |||
| Log generator for images. | |||
| This is a log generator writing images. User can use it to generate fake | |||
| summary logs about images. | |||
| """ | |||
| def generate_event(self, values): | |||
| """ | |||
| Method for generating image event. | |||
| Args: | |||
| values (dict): A dict contains: | |||
| { | |||
| wall_time (float): Timestamp. | |||
| step (int): Train step. | |||
| image (np.array): Pixels tensor. | |||
| tag (str): Tag name. | |||
| } | |||
| Returns: | |||
| summary_pb2.Event. | |||
| """ | |||
| image_event = summary_pb2.Event() | |||
| image_event.wall_time = values.get('wall_time') | |||
| image_event.step = values.get('step') | |||
| height, width, channel, image_string = self._get_image_string(values.get('image')) | |||
| value = image_event.summary.value.add() | |||
| value.tag = values.get('tag') | |||
| value.image.height = height | |||
| value.image.width = width | |||
| value.image.colorspace = channel | |||
| value.image.encoded_image = image_string | |||
| return image_event | |||
| def _get_image_string(self, image_tensor): | |||
| """ | |||
| Generate image string from tensor. | |||
| Args: | |||
| image_tensor (np.array): Pixels tensor. | |||
| Returns: | |||
| int, height. | |||
| int, width. | |||
| int, channel. | |||
| bytes, image_string. | |||
| """ | |||
| height, width, channel = image_tensor.shape | |||
| scaled_height = int(height) | |||
| scaled_width = int(width) | |||
| image = Image.fromarray(image_tensor) | |||
| image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS) | |||
| output = io.BytesIO() | |||
| image.save(output, format='PNG') | |||
| image_string = output.getvalue() | |||
| output.close() | |||
| return height, width, channel, image_string | |||
| def _make_image_tensor(self, shape): | |||
| """ | |||
| Make image tensor according to shape. | |||
| Args: | |||
| shape (list): Shape of image, consists of height, width, channel. | |||
| Returns: | |||
| np.array, image tensor. | |||
| """ | |||
| image = np.prod(shape) | |||
| image_tensor = (np.arange(image, dtype=float)).reshape(shape) | |||
| image_tensor = image_tensor / np.max(image_tensor) * 255 | |||
| image_tensor = image_tensor.astype(np.uint8) | |||
| return image_tensor | |||
| def generate_log(self, file_path, steps_list, tag_name): | |||
| """ | |||
| Generate log for external calls. | |||
| Args: | |||
| file_path (str): Path to write logs. | |||
| steps_list (list): A list consists of step. | |||
| tag_name (str): Tag name. | |||
| Returns: | |||
| list[dict], generated image metadata. | |||
| dict, generated image tensors. | |||
| """ | |||
| images_values = dict() | |||
| images_metadata = [] | |||
| for step in steps_list: | |||
| wall_time = time.time() | |||
| # height, width, channel | |||
| image_tensor = self._make_image_tensor([5, 5, 3]) | |||
| image_metadata = dict() | |||
| image_metadata.update({'wall_time': wall_time}) | |||
| image_metadata.update({'step': step}) | |||
| image_metadata.update({'height': image_tensor.shape[0]}) | |||
| image_metadata.update({'width': image_tensor.shape[1]}) | |||
| images_metadata.append(image_metadata) | |||
| images_values.update({step: image_tensor}) | |||
| values = dict( | |||
| wall_time=wall_time, | |||
| step=step, | |||
| image=image_tensor, | |||
| tag=tag_name | |||
| ) | |||
| self._write_log_one_step(file_path, values) | |||
| return images_metadata, images_values | |||
| def get_image_tensor_from_bytes(self, image_string): | |||
| """Get image tensor from bytes.""" | |||
| img = Image.open(io.BytesIO(image_string)) | |||
| image_tensor = np.array(img) | |||
| return image_tensor | |||
| if __name__ == "__main__": | |||
| images_log_generator = ImagesLogGenerator() | |||
| test_file_name = '%s.%s.%s' % ('image', 'summary', str(time.time())) | |||
| test_steps = [1, 3, 5] | |||
| test_tags = "test_image_tag_name" | |||
| images_log_generator.generate_log(test_file_name, test_steps, test_tags) | |||
| @@ -1,75 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Base log Generator.""" | |||
| import struct | |||
| from abc import abstractmethod | |||
| from tests.ut.datavisual.utils import crc32 | |||
| class LogGenerator: | |||
| """ | |||
| Base log generator. | |||
| This is a base class for log generators. User can use it to generate fake | |||
| summary logs. | |||
| """ | |||
| @abstractmethod | |||
| def generate_event(self, values): | |||
| """ | |||
| Abstract method for generating event. | |||
| Args: | |||
| values (dict): Values. | |||
| Returns: | |||
| summary_pb2.Event. | |||
| """ | |||
| def _write_log_one_step(self, file_path, values): | |||
| """ | |||
| Write log one step. | |||
| Args: | |||
| file_path (str): File path to write. | |||
| values (dict): Values. | |||
| """ | |||
| event = self.generate_event(values) | |||
| self._write_log_from_event(file_path, event) | |||
| @staticmethod | |||
| def _write_log_from_event(file_path, event): | |||
| """ | |||
| Write log by event. | |||
| Args: | |||
| file_path (str): File path to write. | |||
| event (summary_pb2.Event): Event object in proto. | |||
| """ | |||
| send_msg = event.SerializeToString() | |||
| header = struct.pack('<Q', len(send_msg)) | |||
| header_crc = struct.pack('<I', crc32.get_mask_from_string(header)) | |||
| footer_crc = struct.pack('<I', crc32.get_mask_from_string(send_msg)) | |||
| write_event = header + header_crc + send_msg + footer_crc | |||
| with open(file_path, "ab") as f: | |||
| f.write(write_event) | |||
| @@ -1,100 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Log generator for scalars.""" | |||
| import time | |||
| import numpy as np | |||
| from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| class ScalarsLogGenerator(LogGenerator): | |||
| """ | |||
| Log generator for scalars. | |||
| This is a log generator writing scalars. User can use it to generate fake | |||
| summary logs about scalar. | |||
| """ | |||
| def generate_event(self, values): | |||
| """ | |||
| Method for generating scalar event. | |||
| Args: | |||
| values (dict): A dict contains: | |||
| { | |||
| wall_time (float): Timestamp. | |||
| step (int): Train step. | |||
| value (float): Scalar value. | |||
| tag (str): Tag name. | |||
| } | |||
| Returns: | |||
| summary_pb2.Event. | |||
| """ | |||
| scalar_event = summary_pb2.Event() | |||
| scalar_event.wall_time = values.get('wall_time') | |||
| scalar_event.step = values.get('step') | |||
| value = scalar_event.summary.value.add() | |||
| value.tag = values.get('tag') | |||
| value.scalar_value = values.get('value') | |||
| return scalar_event | |||
| def generate_log(self, file_path, steps_list, tag_name): | |||
| """ | |||
| Generate log for external calls. | |||
| Args: | |||
| file_path (str): Path to write logs. | |||
| steps_list (list): A list consists of step. | |||
| tag_name (str): Tag name. | |||
| Returns: | |||
| list[dict], generated scalar metadata. | |||
| None, to be consistent with return value of ImageGenerator. | |||
| """ | |||
| scalars_metadata = [] | |||
| for step in steps_list: | |||
| scalar_metadata = dict() | |||
| wall_time = time.time() | |||
| value = np.random.rand() | |||
| scalar_metadata.update({'wall_time': wall_time}) | |||
| scalar_metadata.update({'step': step}) | |||
| scalar_metadata.update({'value': value}) | |||
| scalars_metadata.append(scalar_metadata) | |||
| scalar_metadata.update({"tag": tag_name}) | |||
| self._write_log_one_step(file_path, scalar_metadata) | |||
| return scalars_metadata, None | |||
| if __name__ == "__main__": | |||
| scalars_log_generator = ScalarsLogGenerator() | |||
| test_file_name = '%s.%s.%s' % ('scalar', 'summary', str(time.time())) | |||
| test_steps = [1, 3, 5] | |||
| test_tag = "test_scalar_tag_name" | |||
| scalars_log_generator.generate_log(test_file_name, test_steps, test_tag) | |||
| @@ -1,83 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """log operations module.""" | |||
| import json | |||
| import os | |||
| import time | |||
| from tests.ut.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator | |||
| from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator | |||
| from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| log_generators = { | |||
| PluginNameEnum.GRAPH.value: GraphLogGenerator(), | |||
| PluginNameEnum.IMAGE.value: ImagesLogGenerator(), | |||
| PluginNameEnum.SCALAR.value: ScalarsLogGenerator() | |||
| } | |||
| class LogOperations: | |||
| """Log Operations class.""" | |||
| @staticmethod | |||
| def generate_log(plugin_name, log_dir, log_settings, valid=True): | |||
| """ | |||
| Generate log. | |||
| Args: | |||
| plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'. | |||
| log_dir (str): Log path to write log. | |||
| log_settings (dict): Info about the log, e.g.: | |||
| { | |||
| current_time (int): Timestamp in summary file name, not necessary. | |||
| graph_base_path (str): Path of graph_bas.json, necessary for `graph`. | |||
| steps (list[int]): Steps for `image` and `scalar`, default is [1]. | |||
| tag (str): Tag name, default is 'default_tag'. | |||
| } | |||
| valid (bool): If true, summary name will be valid. | |||
| Returns: | |||
| str, Summary log path. | |||
| """ | |||
| current_time = log_settings.get('time', int(time.time())) | |||
| current_time = int(current_time) | |||
| log_generator = log_generators.get(plugin_name) | |||
| if valid: | |||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time))) | |||
| else: | |||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time))) | |||
| if plugin_name == PluginNameEnum.GRAPH.value: | |||
| graph_base_path = log_settings.get('graph_base_path') | |||
| with open(graph_base_path, 'r') as load_f: | |||
| graph_dict = json.load(load_f) | |||
| graph_dict = log_generator.generate_log(temp_path, graph_dict) | |||
| return temp_path, graph_dict | |||
| steps_list = log_settings.get('steps', [1]) | |||
| tag_name = log_settings.get('tag', 'default_tag') | |||
| metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name) | |||
| return temp_path, metadata, values | |||
| @staticmethod | |||
| def get_log_generator(plugin_name): | |||
| """Get log generator.""" | |||
| return log_generators.get(plugin_name) | |||
| @@ -1,59 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Description: This file is used for some common util. | |||
| """ | |||
| import os | |||
| import shutil | |||
| import time | |||
| from urllib.parse import urlencode | |||
| from mindinsight.datavisual.common.enums import DataManagerStatus | |||
| def get_url(url, params): | |||
| """ | |||
| Concatenate the URL and params. | |||
| Args: | |||
| url (str): A link requested. For example, http://example.com. | |||
| params (dict): A dict consists of params. For example, {'offset': 1, 'limit':'100}. | |||
| Returns: | |||
| str, like http://example.com?offset=1&limit=100 | |||
| """ | |||
| return url + '?' + urlencode(params) | |||
| def delete_files_or_dirs(path_list): | |||
| """Delete files or dirs in path_list.""" | |||
| for path in path_list: | |||
| if os.path.isdir(path): | |||
| shutil.rmtree(path) | |||
| else: | |||
| os.remove(path) | |||
| def check_loading_done(data_manager, time_limit=15): | |||
| """If loading data for more than `time_limit` seconds, exit.""" | |||
| start_time = time.time() | |||
| while data_manager.status != DataManagerStatus.DONE.value: | |||
| time_used = time.time() - start_time | |||
| if time_used > time_limit: | |||
| break | |||
| time.sleep(0.1) | |||
| continue | |||
| @@ -14,6 +14,6 @@ | |||
| # ============================================================================ | |||
| """Import the mocked mindspore.""" | |||
| import sys | |||
| from .collection.model import mindspore | |||
| from ...utils import mindspore | |||
| sys.modules['mindspore'] = mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock MindSpore Interface.""" | |||
| from .application.model_zoo.resnet import ResNet | |||
| from .common.tensor import Tensor | |||
| from .dataset import MindDataset | |||
| from .nn import * | |||
| from .train.callback import _ListCallback, Callback, RunContext, ModelCheckpoint, SummaryStep | |||
| from .train.summary import SummaryRecord | |||
| @@ -1,14 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -1,14 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -1,24 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore ResNet class.""" | |||
| from ...nn.cell import Cell | |||
| class ResNet(Cell): | |||
| """Mocked ResNet.""" | |||
| def __init__(self): | |||
| super(ResNet, self).__init__() | |||
| self._cells = {} | |||
| @@ -1,14 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -1,30 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore mindspore/common/tensor.py.""" | |||
| import numpy as np | |||
| class Tensor: | |||
| """Mock the MindSpore Tensor class.""" | |||
| def __init__(self, value=0): | |||
| self._value = value | |||
| def asnumpy(self): | |||
| """Get value in numpy format.""" | |||
| return np.array(self._value) | |||
| def __repr__(self): | |||
| return str(self.asnumpy()) | |||
| @@ -1,21 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """MindSpore Mock Interface""" | |||
| def get_context(key): | |||
| """Get key in context.""" | |||
| context = {"device_id": 1} | |||
| return context.get(key) | |||
| @@ -1,16 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock mindspore.dataset.""" | |||
| from .engine import MindDataset | |||
| @@ -1,16 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock mindspore.dataset.engine.""" | |||
| from .datasets import MindDataset, Dataset | |||
| @@ -1,36 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore mindspore/dataset/engine/datasets.py.""" | |||
| class Dataset: | |||
| """Mock the MindSpore Dataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_path=None): | |||
| self.dataset_size = dataset_size | |||
| self.dataset_path = dataset_path | |||
| self.input = [] | |||
| def get_dataset_size(self): | |||
| """Mocked get_dataset_size.""" | |||
| return self.dataset_size | |||
| class MindDataset(Dataset): | |||
| """Mock the MindSpore MindDataset class.""" | |||
| def __init__(self, dataset_size=None, dataset_file=None): | |||
| super(MindDataset, self).__init__(dataset_size) | |||
| self.dataset_file = dataset_file | |||
| @@ -1,22 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the mindspore.nn package.""" | |||
| from .optim import Optimizer, Momentum | |||
| from .loss.loss import SoftmaxCrossEntropyWithLogits, _Loss | |||
| from .cell import Cell, WithLossCell, TrainOneStepWithLossScaleCell | |||
| __all__ = ['Optimizer', 'Momentum', 'SoftmaxCrossEntropyWithLogits', | |||
| '_Loss', 'Cell', 'WithLossCell', | |||
| 'TrainOneStepWithLossScaleCell'] | |||
| @@ -1,51 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore mindspore/train/callback.py.""" | |||
| class Cell: | |||
| """Mock the Cell class.""" | |||
| def __init__(self, auto_prefix=True, pips=None): | |||
| if pips is None: | |||
| pips = dict() | |||
| self._auto_prefix = auto_prefix | |||
| self._pips = pips | |||
| @property | |||
| def auto_prefix(self): | |||
| """The property of auto_prefix.""" | |||
| return self._auto_prefix | |||
| @property | |||
| def pips(self): | |||
| """The property of pips.""" | |||
| return self._pips | |||
| class WithLossCell(Cell): | |||
| """Mocked WithLossCell class.""" | |||
| def __init__(self, backbone, loss_fn): | |||
| super(WithLossCell, self).__init__(auto_prefix=False, pips=backbone.pips) | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| class TrainOneStepWithLossScaleCell(Cell): | |||
| """Mocked TrainOneStepWithLossScaleCell.""" | |||
| def __init__(self): | |||
| super(TrainOneStepWithLossScaleCell, self).__init__() | |||
| @@ -1,14 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -1,39 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore SoftmaxCrossEntropyWithLogits class.""" | |||
| from ..cell import Cell | |||
| class _Loss(Cell): | |||
| """Mocked _Loss.""" | |||
| def __init__(self, reduction='mean'): | |||
| super(_Loss, self).__init__() | |||
| self.reduction = reduction | |||
| def construct(self, base, target): | |||
| """Mocked construct function.""" | |||
| raise NotImplementedError | |||
| class SoftmaxCrossEntropyWithLogits(_Loss): | |||
| """Mocked SoftmaxCrossEntropyWithLogits.""" | |||
| def __init__(self, weight=None): | |||
| super(SoftmaxCrossEntropyWithLogits, self).__init__(weight) | |||
| def construct(self, base, target): | |||
| """Mocked construct.""" | |||
| return 1 | |||
| @@ -1,49 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore mindspore/nn/optim.py.""" | |||
| from .cell import Cell | |||
| class Parameter: | |||
| """Mock the MindSpore Parameter class.""" | |||
| def __init__(self, learning_rate): | |||
| self._name = "Parameter" | |||
| self.default_input = learning_rate | |||
| @property | |||
| def name(self): | |||
| """The property of name.""" | |||
| return self._name | |||
| def __repr__(self): | |||
| format_str = 'Parameter (name={name})' | |||
| return format_str.format(name=self._name) | |||
| class Optimizer(Cell): | |||
| """Mock the MindSpore Optimizer class.""" | |||
| def __init__(self, learning_rate): | |||
| super(Optimizer, self).__init__() | |||
| self.learning_rate = Parameter(learning_rate) | |||
| class Momentum(Optimizer): | |||
| """Mock the MindSpore Momentum class.""" | |||
| def __init__(self, learning_rate): | |||
| super(Momentum, self).__init__(learning_rate) | |||
| self.dynamic_lr = False | |||
| @@ -1,17 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock MindSpore wrap package.""" | |||
| from .loss_scale import TrainOneStepWithLossScaleCell | |||
| from .cell_wrapper import WithLossCell | |||
| @@ -1,25 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock MindSpore cell_wrapper.py.""" | |||
| from ..cell import Cell | |||
| class WithLossCell(Cell): | |||
| """Mock the WithLossCell class.""" | |||
| def __init__(self, backbone, loss_fn): | |||
| super(WithLossCell, self).__init__() | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| @@ -1,29 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock MindSpore loss_scale.py.""" | |||
| from ..cell import Cell | |||
| class TrainOneStepWithLossScaleCell(Cell): | |||
| """Mock the TrainOneStepWithLossScaleCell class.""" | |||
| def __init__(self, network, optimizer): | |||
| super(TrainOneStepWithLossScaleCell, self).__init__() | |||
| self.network = network | |||
| self.optimizer = optimizer | |||
| def construct(self, data, label): | |||
| """Mock the construct method.""" | |||
| raise NotImplementedError | |||
| @@ -1,14 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -1,84 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Mock the MindSpore mindspore/train/callback.py.""" | |||
| import os | |||
| class RunContext: | |||
| """Mock the RunContext class.""" | |||
| def __init__(self, original_args=None): | |||
| self._original_args = original_args | |||
| self._stop_requested = False | |||
| def original_args(self): | |||
| """Mock original_args.""" | |||
| return self._original_args | |||
| def stop_requested(self): | |||
| """Mock stop_requested method.""" | |||
| return self._stop_requested | |||
| class Callback: | |||
| """Mock the Callback class.""" | |||
| def __init__(self): | |||
| pass | |||
| def begin(self, run_context): | |||
| """Called once before network training.""" | |||
| def epoch_begin(self, run_context): | |||
| """Called before each epoch begin.""" | |||
| class _ListCallback(Callback): | |||
| """Mock the _ListCallabck class.""" | |||
| def __init__(self, callbacks): | |||
| super(_ListCallback, self).__init__() | |||
| self._callbacks = callbacks | |||
| class ModelCheckpoint(Callback): | |||
| """Mock the ModelCheckpoint class.""" | |||
| def __init__(self, prefix='CKP', directory=None, config=None): | |||
| super(ModelCheckpoint, self).__init__() | |||
| self._prefix = prefix | |||
| self._directory = directory | |||
| self._config = config | |||
| self._latest_ckpt_file_name = os.path.join(directory, prefix + 'test_model.ckpt') | |||
| @property | |||
| def model_file_name(self): | |||
| """Get the file name of model.""" | |||
| return self._model_file_name | |||
| @property | |||
| def latest_ckpt_file_name(self): | |||
| """Get the latest file name fo checkpoint.""" | |||
| return self._latest_ckpt_file_name | |||
| class SummaryStep(Callback): | |||
| """Mock the SummaryStep class.""" | |||
| def __init__(self, summary, flush_step=10): | |||
| super(SummaryStep, self).__init__() | |||
| self._sumamry = summary | |||
| self._flush_step = flush_step | |||
| self.summary_file_name = summary.full_file_name | |||
| @@ -1,18 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """MindSpore Mock Interface""" | |||
| from .summary_record import SummaryRecord | |||
| __all__ = ["SummaryRecord"] | |||
| @@ -1,38 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """MindSpore Mock Interface""" | |||
| import os | |||
| import time | |||
| class SummaryRecord: | |||
| """Mock the MindSpore SummaryRecord class.""" | |||
| def __init__(self, | |||
| log_dir: str, | |||
| file_prefix: str = "events.", | |||
| file_suffix: str = ".MS", | |||
| create_time=int(time.time())): | |||
| self.log_dir = log_dir | |||
| self.prefix = file_prefix | |||
| self.suffix = file_suffix | |||
| file_name = file_prefix + 'summary.' + str(create_time) + file_suffix | |||
| self.full_file_name = os.path.join(log_dir, file_name) | |||
| def flush(self): | |||
| """Mock flush method.""" | |||
| def close(self): | |||
| """Mock close method.""" | |||
| @@ -16,21 +16,21 @@ | |||
| import os | |||
| import shutil | |||
| import unittest | |||
| from unittest import mock, TestCase | |||
| from unittest import TestCase, mock | |||
| from unittest.mock import MagicMock | |||
| from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ | |||
| AnalyzeObject | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||
| LineageLogError, LineageGetModelFileError, MindInsightException | |||
| from mindinsight.lineagemgr.collection.model.model_lineage import AnalyzeObject, EvalLineage, TrainLineage | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError, | |||
| MindInsightException) | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.dataset.engine import MindDataset, Dataset | |||
| from mindspore.nn import Optimizer, WithLossCell, TrainOneStepWithLossScaleCell, \ | |||
| SoftmaxCrossEntropyWithLogits | |||
| from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep | |||
| from mindspore.dataset.engine import Dataset, MindDataset | |||
| from mindspore.nn import Optimizer, SoftmaxCrossEntropyWithLogits, TrainOneStepWithLossScaleCell, WithLossCell | |||
| from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep | |||
| from mindspore.train.summary import SummaryRecord | |||
| @mock.patch('builtins.open') | |||
| @mock.patch('os.makedirs') | |||
| class TestModelLineage(TestCase): | |||
| """Test TrainLineage and EvalLineage class in model_lineage.py.""" | |||
| @@ -51,23 +51,19 @@ class TestModelLineage(TestCase): | |||
| cls.summary_log_path = '/path/to/summary_log' | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||
| def test_summary_record_exception(self, mock_validate_summary): | |||
| def test_summary_record_exception(self, *args): | |||
| """Test SummaryRecord with exception.""" | |||
| mock_validate_summary.return_value = None | |||
| args[0].return_value = None | |||
| summary_record = self.my_summary_record(self.summary_log_path) | |||
| with self.assertRaises(MindInsightException) as context: | |||
| self.my_train_module(summary_record=summary_record, raise_exception=1) | |||
| self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception)) | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'LineageSummary.record_dataset_graph') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'validate_summary_record') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'AnalyzeObject.get_optimizer_by_network') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | |||
| def test_begin(self, *args): | |||
| """Test TrainLineage.begin method.""" | |||
| @@ -82,14 +78,10 @@ class TestModelLineage(TestCase): | |||
| args[4].assert_called() | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'LineageSummary.record_dataset_graph') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'validate_summary_record') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'AnalyzeObject.get_optimizer_by_network') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||
| 'AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | |||
| def test_begin_error(self, *args): | |||
| """Test TrainLineage.begin method.""" | |||
| @@ -122,15 +114,11 @@ class TestModelLineage(TestCase): | |||
| train_lineage.begin(self.my_run_context(run_context)) | |||
| self.assertTrue('The parameter optimizer is invalid.' in str(context.exception)) | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | |||
| @mock.patch('builtins.float') | |||
| @@ -150,23 +138,19 @@ class TestModelLineage(TestCase): | |||
| args[6].assert_called() | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||
| def test_train_end_exception(self, mock_validate_summary): | |||
| def test_train_end_exception(self, *args): | |||
| """Test TrainLineage.end method when exception.""" | |||
| mock_validate_summary.return_value = True | |||
| args[0].return_value = True | |||
| train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True) | |||
| with self.assertRaises(Exception) as context: | |||
| train_lineage.end(self.run_context) | |||
| self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception)) | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | |||
| @mock.patch('builtins.float') | |||
| @@ -186,15 +170,11 @@ class TestModelLineage(TestCase): | |||
| train_lineage.end(self.my_run_context(self.run_context)) | |||
| self.assertTrue('End error in TrainLineage:' in str(context.exception)) | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||
| @mock.patch( | |||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | |||
| @mock.patch('builtins.float') | |||
| @@ -218,9 +198,9 @@ class TestModelLineage(TestCase): | |||
| self.assertTrue('End error in TrainLineage:' in str(context.exception)) | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||
| def test_eval_exception_train_id_none(self, mock_validate_summary): | |||
| def test_eval_exception_train_id_none(self, *args): | |||
| """Test EvalLineage.end method with initialization error.""" | |||
| mock_validate_summary.return_value = True | |||
| args[0].return_value = True | |||
| with self.assertRaises(MindInsightException) as context: | |||
| self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2) | |||
| self.assertTrue('Invalid value for raise_exception.' in str(context.exception)) | |||
| @@ -242,9 +222,9 @@ class TestModelLineage(TestCase): | |||
| args[0].assert_called() | |||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||
| def test_eval_end_except_run_context(self, mock_validate_summary): | |||
| def test_eval_end_except_run_context(self, *args): | |||
| """Test EvalLineage.end method when run_context is invalid..""" | |||
| mock_validate_summary.return_value = True | |||
| args[0].return_value = True | |||
| eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True) | |||
| with self.assertRaises(Exception) as context: | |||
| eval_lineage.end(self.run_context) | |||
| @@ -284,8 +264,9 @@ class TestModelLineage(TestCase): | |||
| eval_lineage.end(self.my_run_context(self.run_context)) | |||
| self.assertTrue('End error in EvalLineage' in str(context.exception)) | |||
| def test_epoch_is_zero(self): | |||
| def test_epoch_is_zero(self, *args): | |||
| """Test TrainLineage.end method.""" | |||
| args[0].return_value = None | |||
| run_context = self.run_context | |||
| run_context['epoch_num'] = 0 | |||
| with self.assertRaises(MindInsightException): | |||
| @@ -345,7 +326,7 @@ class TestAnalyzer(TestCase): | |||
| ) | |||
| res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train') | |||
| res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid') | |||
| assert res1 == {'step_num': 10, | |||
| assert res1 == {'step_num': 10, | |||
| 'train_dataset_path': '/path/to', | |||
| 'train_dataset_size': 50, | |||
| 'epoch': 2} | |||
| @@ -20,23 +20,44 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||
| from mindinsight.lineagemgr.common.path_parser import SummaryPathParser | |||
| MOCK_SUMMARY_DIRS = [ | |||
| { | |||
| 'relative_path': './relative_path0' | |||
| }, | |||
| { | |||
| 'relative_path': './' | |||
| }, | |||
| { | |||
| 'relative_path': './relative_path1' | |||
| } | |||
| ] | |||
| MOCK_SUMMARIES = [ | |||
| { | |||
| 'file_name': 'file0', | |||
| 'create_time': datetime.fromtimestamp(1582031970) | |||
| }, | |||
| { | |||
| 'file_name': 'file0_lineage', | |||
| 'create_time': datetime.fromtimestamp(1582031970) | |||
| }, | |||
| { | |||
| 'file_name': 'file1', | |||
| 'create_time': datetime.fromtimestamp(1582031971) | |||
| }, | |||
| { | |||
| 'file_name': 'file1_lineage', | |||
| 'create_time': datetime.fromtimestamp(1582031971) | |||
| } | |||
| ] | |||
| class TestSummaryPathParser(TestCase): | |||
| """Test the class of SummaryPathParser.""" | |||
| @mock.patch.object(SummaryWatcher, 'list_summary_directories') | |||
| def test_get_summary_dirs(self, *args): | |||
| """Test the function of get_summary_dirs.""" | |||
| args[0].return_value = [ | |||
| { | |||
| 'relative_path': './relative_path0' | |||
| }, | |||
| { | |||
| 'relative_path': './' | |||
| }, | |||
| { | |||
| 'relative_path': './relative_path1' | |||
| } | |||
| ] | |||
| args[0].return_value = MOCK_SUMMARY_DIRS | |||
| expected_result = [ | |||
| '/path/to/base/relative_path0', | |||
| @@ -54,24 +75,7 @@ class TestSummaryPathParser(TestCase): | |||
| @mock.patch.object(SummaryWatcher, 'list_summaries') | |||
| def test_get_latest_lineage_summary(self, *args): | |||
| """Test the function of get_latest_lineage_summary.""" | |||
| args[0].return_value = [ | |||
| { | |||
| 'file_name': 'file0', | |||
| 'create_time': datetime.fromtimestamp(1582031970) | |||
| }, | |||
| { | |||
| 'file_name': 'file0_lineage', | |||
| 'create_time': datetime.fromtimestamp(1582031970) | |||
| }, | |||
| { | |||
| 'file_name': 'file1', | |||
| 'create_time': datetime.fromtimestamp(1582031971) | |||
| }, | |||
| { | |||
| 'file_name': 'file1_lineage', | |||
| 'create_time': datetime.fromtimestamp(1582031971) | |||
| } | |||
| ] | |||
| args[0].return_value = MOCK_SUMMARIES | |||
| summary_dir = '/path/to/summary_dir' | |||
| result = SummaryPathParser.get_latest_lineage_summary(summary_dir) | |||
| self.assertEqual('/path/to/summary_dir/file1_lineage', result) | |||
| @@ -119,35 +123,8 @@ class TestSummaryPathParser(TestCase): | |||
| @mock.patch.object(SummaryWatcher, 'list_summary_directories') | |||
| def test_get_latest_lineage_summaries(self, *args): | |||
| """Test the function of get_latest_lineage_summaries.""" | |||
| args[0].return_value = [ | |||
| { | |||
| 'relative_path': './relative_path0' | |||
| }, | |||
| { | |||
| 'relative_path': './' | |||
| }, | |||
| { | |||
| 'relative_path': './relative_path1' | |||
| } | |||
| ] | |||
| args[1].return_value = [ | |||
| { | |||
| 'file_name': 'file0', | |||
| 'create_time': datetime.fromtimestamp(1582031970) | |||
| }, | |||
| { | |||
| 'file_name': 'file0_lineage', | |||
| 'create_time': datetime.fromtimestamp(1582031970) | |||
| }, | |||
| { | |||
| 'file_name': 'file1', | |||
| 'create_time': datetime.fromtimestamp(1582031971) | |||
| }, | |||
| { | |||
| 'file_name': 'file1_lineage', | |||
| 'create_time': datetime.fromtimestamp(1582031971) | |||
| } | |||
| ] | |||
| args[0].return_value = MOCK_SUMMARY_DIRS | |||
| args[1].return_value = MOCK_SUMMARIES | |||
| expected_result = [ | |||
| '/path/to/base/relative_path0/file1_lineage', | |||
| @@ -15,38 +15,31 @@ | |||
| """Test the validate module.""" | |||
| from unittest import TestCase | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||
| LineageParamValueError, LineageParamTypeError | |||
| from mindinsight.lineagemgr.common.validator.model_parameter import \ | |||
| SearchModelConditionParameter | |||
| from mindinsight.lineagemgr.common.validator.validate import \ | |||
| validate_search_model_condition | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError | |||
| from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter | |||
| from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| class TestValidateSearchModelCondition(TestCase): | |||
| """Test the mothod of validate_search_model_condition.""" | |||
| def test_validate_search_model_condition(self): | |||
| """Test the mothod of validate_search_model_condition.""" | |||
| def test_validate_search_model_condition_param_type_error(self): | |||
| """Test the mothod of validate_search_model_condition with LineageParamTypeError.""" | |||
| condition = { | |||
| 'summary_dir': 'xxx' | |||
| } | |||
| self.assertRaisesRegex( | |||
| LineageParamTypeError, | |||
| self._assert_raise_of_lineage_param_type_error( | |||
| 'The search_condition element summary_dir should be dict.', | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| def test_validate_search_model_condition_param_value_error(self): | |||
| """Test the mothod of validate_search_model_condition with LineageParamValueError.""" | |||
| condition = { | |||
| 'xxx': 'xxx' | |||
| } | |||
| self.assertRaisesRegex( | |||
| LineageParamValueError, | |||
| self._assert_raise_of_lineage_param_value_error( | |||
| 'The search attribute not supported.', | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| @@ -55,22 +48,38 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'xxx': 'xxx' | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| LineageParamValueError, | |||
| self._assert_raise_of_lineage_param_value_error( | |||
| "The compare condition should be in", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| condition = { | |||
| 1: { | |||
| "ge": 8.0 | |||
| } | |||
| } | |||
| self._assert_raise_of_lineage_param_value_error( | |||
| "The search attribute not supported.", | |||
| condition | |||
| ) | |||
| condition = { | |||
| 'metric_': { | |||
| "ge": 8.0 | |||
| } | |||
| } | |||
| self._assert_raise_of_lineage_param_value_error( | |||
| "The search attribute not supported.", | |||
| condition | |||
| ) | |||
| def test_validate_search_model_condition_mindinsight_exception_1(self): | |||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||
| condition = { | |||
| "offset": 100001 | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "Invalid input offset. 0 <= offset <= 100000", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| @@ -80,11 +89,9 @@ class TestValidateSearchModelCondition(TestCase): | |||
| }, | |||
| 'limit': 10 | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter summary_dir is invalid. It should be a dict and the value should be a string", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter summary_dir is invalid. It should be a dict and " | |||
| "the value should be a string", | |||
| condition | |||
| ) | |||
| @@ -93,11 +100,9 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'in': 1.0 | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter learning_rate is invalid. It should be a dict and " | |||
| "the value should be a float or a integer", | |||
| condition | |||
| ) | |||
| @@ -106,24 +111,22 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'lt': True | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter learning_rate is invalid. It should be a dict and " | |||
| "the value should be a float or a integer", | |||
| condition | |||
| ) | |||
| def test_validate_search_model_condition_mindinsight_exception_2(self): | |||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||
| condition = { | |||
| 'learning_rate': { | |||
| 'gt': [1.0] | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter learning_rate is invalid. It should be a dict and " | |||
| "the value should be a float or a integer", | |||
| condition | |||
| ) | |||
| @@ -132,11 +135,9 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'ge': 1 | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter loss_function is invalid. It should be a dict and the value should be a string", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter loss_function is invalid. It should be a dict and " | |||
| "the value should be a string", | |||
| condition | |||
| ) | |||
| @@ -145,12 +146,9 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'in': 2 | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter train_dataset_count is invalid. It should be a dict " | |||
| "and the value should be a integer between 0", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| @@ -162,14 +160,14 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'eq': 'xxx' | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter network is invalid. It should be a dict and the value should be a string", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter network is invalid. It should be a dict and " | |||
| "the value should be a string", | |||
| condition | |||
| ) | |||
| def test_validate_search_model_condition_mindinsight_exception_3(self): | |||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||
| condition = { | |||
| 'batch_size': { | |||
| 'lt': 2, | |||
| @@ -179,11 +177,8 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'eq': 222 | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter batch_size is invalid. It should be a non-negative integer.", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| @@ -192,12 +187,9 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'lt': -2 | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter test_dataset_count is invalid. It should be a dict " | |||
| "and the value should be a integer between 0", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| @@ -206,11 +198,8 @@ class TestValidateSearchModelCondition(TestCase): | |||
| 'lt': False | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter epoch is invalid. It should be a positive integer.", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| @@ -219,65 +208,79 @@ class TestValidateSearchModelCondition(TestCase): | |||
| "ge": "" | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter learning_rate is invalid. It should be a dict and " | |||
| "the value should be a float or a integer", | |||
| condition | |||
| ) | |||
| def test_validate_search_model_condition_mindinsight_exception_4(self): | |||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||
| condition = { | |||
| "train_dataset_count": { | |||
| "ge": 8.0 | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter train_dataset_count is invalid. It should be a dict " | |||
| "and the value should be a integer between 0", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| condition = { | |||
| 1: { | |||
| "ge": 8.0 | |||
| 'metric_attribute': { | |||
| 'ge': 'xxx' | |||
| } | |||
| } | |||
| self.assertRaisesRegex( | |||
| LineageParamValueError, | |||
| "The search attribute not supported.", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| self._assert_raise_of_mindinsight_exception( | |||
| "The parameter metric_attribute is invalid. " | |||
| "It should be a dict and the value should be a float or a integer", | |||
| condition | |||
| ) | |||
| condition = { | |||
| 'metric_': { | |||
| "ge": 8.0 | |||
| } | |||
| } | |||
| LineageParamValueError('The search attribute not supported.') | |||
| self.assertRaisesRegex( | |||
| LineageParamValueError, | |||
| "The search attribute not supported.", | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| def _assert_raise(self, exception, msg, condition): | |||
| """ | |||
| Assert raise by unittest. | |||
| condition = { | |||
| 'metric_attribute': { | |||
| 'ge': 'xxx' | |||
| } | |||
| } | |||
| Args: | |||
| exception (Type): Exception class expected to be raised. | |||
| msg (msg): Expected error message. | |||
| condition (dict): The parameter of search condition. | |||
| """ | |||
| self.assertRaisesRegex( | |||
| MindInsightException, | |||
| "The parameter metric_attribute is invalid. " | |||
| "It should be a dict and the value should be a float or a integer", | |||
| exception, | |||
| msg, | |||
| validate_search_model_condition, | |||
| SearchModelConditionParameter, | |||
| condition | |||
| ) | |||
| def _assert_raise_of_mindinsight_exception(self, msg, condition): | |||
| """ | |||
| Assert raise of MindinsightException by unittest. | |||
| Args: | |||
| msg (msg): Expected error message. | |||
| condition (dict): The parameter of search condition. | |||
| """ | |||
| self._assert_raise(MindInsightException, msg, condition) | |||
| def _assert_raise_of_lineage_param_value_error(self, msg, condition): | |||
| """ | |||
| Assert raise of LineageParamValueError by unittest. | |||
| Args: | |||
| msg (msg): Expected error message. | |||
| condition (dict): The parameter of search condition. | |||
| """ | |||
| self._assert_raise(LineageParamValueError, msg, condition) | |||
| def _assert_raise_of_lineage_param_type_error(self, msg, condition): | |||
| """ | |||
| Assert raise of LineageParamTypeError by unittest. | |||
| Args: | |||
| msg (msg): Expected error message. | |||
| condition (dict): The parameter of search condition. | |||
| """ | |||
| self._assert_raise(LineageParamTypeError, msg, condition) | |||
| @@ -15,6 +15,8 @@ | |||
| """The event data in querier test.""" | |||
| import json | |||
| from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE | |||
| EVENT_TRAIN_DICT_0 = { | |||
| 'wall_time': 1581499557.7017336, | |||
| 'train_lineage': { | |||
| @@ -373,49 +375,4 @@ EVENT_DATASET_DICT_0 = { | |||
| } | |||
| } | |||
| DATASET_DICT_0 = { | |||
| 'op_type': 'BatchDataset', | |||
| 'op_module': 'minddata.dataengine.datasets', | |||
| 'num_parallel_workers': None, | |||
| 'drop_remainder': True, | |||
| 'batch_size': 10, | |||
| 'children': [ | |||
| { | |||
| 'op_type': 'MapDataset', | |||
| 'op_module': 'minddata.dataengine.datasets', | |||
| 'num_parallel_workers': None, | |||
| 'input_columns': [ | |||
| 'label' | |||
| ], | |||
| 'output_columns': [ | |||
| None | |||
| ], | |||
| 'operations': [ | |||
| { | |||
| 'tensor_op_module': 'minddata.transforms.c_transforms', | |||
| 'tensor_op_name': 'OneHot', | |||
| 'num_classes': 10 | |||
| } | |||
| ], | |||
| 'children': [ | |||
| { | |||
| 'op_type': 'MnistDataset', | |||
| 'shard_id': None, | |||
| 'num_shards': None, | |||
| 'op_module': 'minddata.dataengine.datasets', | |||
| 'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData', | |||
| 'num_parallel_workers': None, | |||
| 'shuffle': None, | |||
| 'num_samples': 100, | |||
| 'sampler': { | |||
| 'sampler_module': 'minddata.dataengine.samplers', | |||
| 'sampler_name': 'RandomSampler', | |||
| 'replacement': True, | |||
| 'num_samples': 100 | |||
| }, | |||
| 'children': [] | |||
| } | |||
| ] | |||
| } | |||
| ] | |||
| } | |||
| DATASET_DICT_0 = SERIALIZED_PIPELINE | |||
| @@ -18,12 +18,12 @@ from unittest import TestCase, mock | |||
| from google.protobuf.json_format import ParseDict | |||
| import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2 | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||
| LineageQuerierParamException, LineageParamTypeError, \ | |||
| LineageSummaryAnalyzeException, LineageSummaryParseException | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageParamTypeError, LineageQuerierParamException, | |||
| LineageSummaryAnalyzeException, | |||
| LineageSummaryParseException) | |||
| from mindinsight.lineagemgr.querier.querier import Querier | |||
| from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \ | |||
| LineageInfo | |||
| from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo | |||
| from . import event_data | |||
| @@ -140,6 +140,98 @@ def get_lineage_infos(): | |||
| return lineage_infos | |||
| LINEAGE_INFO_0 = { | |||
| 'summary_dir': '/path/to/summary0', | |||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||
| 'metric': event_data.METRIC_0, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| } | |||
| LINEAGE_INFO_1 = { | |||
| 'summary_dir': '/path/to/summary1', | |||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||
| 'metric': event_data.METRIC_1, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| } | |||
| LINEAGE_FILTRATION_0 = create_filtration_result( | |||
| '/path/to/summary0', | |||
| event_data.EVENT_TRAIN_DICT_0, | |||
| event_data.EVENT_EVAL_DICT_0, | |||
| event_data.METRIC_0, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_1 = create_filtration_result( | |||
| '/path/to/summary1', | |||
| event_data.EVENT_TRAIN_DICT_1, | |||
| event_data.EVENT_EVAL_DICT_1, | |||
| event_data.METRIC_1, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_2 = create_filtration_result( | |||
| '/path/to/summary2', | |||
| event_data.EVENT_TRAIN_DICT_2, | |||
| event_data.EVENT_EVAL_DICT_2, | |||
| event_data.METRIC_2, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_3 = create_filtration_result( | |||
| '/path/to/summary3', | |||
| event_data.EVENT_TRAIN_DICT_3, | |||
| event_data.EVENT_EVAL_DICT_3, | |||
| event_data.METRIC_3, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_4 = create_filtration_result( | |||
| '/path/to/summary4', | |||
| event_data.EVENT_TRAIN_DICT_4, | |||
| event_data.EVENT_EVAL_DICT_4, | |||
| event_data.METRIC_4, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_5 = { | |||
| "summary_dir": '/path/to/summary5', | |||
| "loss_function": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||
| "test_dataset_path": None, | |||
| "test_dataset_count": None, | |||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||
| "learning_rate": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||
| "metric": {}, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| } | |||
| LINEAGE_FILTRATION_6 = { | |||
| "summary_dir": '/path/to/summary6', | |||
| "loss_function": None, | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": None, | |||
| "test_dataset_path": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||
| "test_dataset_count": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||
| "network": None, | |||
| "optimizer": None, | |||
| "learning_rate": None, | |||
| "epoch": None, | |||
| "batch_size": None, | |||
| "loss": None, | |||
| "model_size": None, | |||
| "metric": event_data.METRIC_5, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| } | |||
| class TestQuerier(TestCase): | |||
| """Test the class of `Querier`.""" | |||
| @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos') | |||
| @@ -169,31 +261,13 @@ class TestQuerier(TestCase): | |||
| def test_get_summary_lineage_success_1(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [ | |||
| { | |||
| 'summary_dir': '/path/to/summary0', | |||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||
| 'metric': event_data.METRIC_0, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| } | |||
| ] | |||
| expected_result = [LINEAGE_INFO_0] | |||
| result = self.single_querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| def test_get_summary_lineage_success_2(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [ | |||
| { | |||
| 'summary_dir': '/path/to/summary0', | |||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||
| 'metric': event_data.METRIC_0, | |||
| 'valid_dataset': | |||
| event_data.EVENT_EVAL_DICT_0['evaluation_lineage'][ | |||
| 'valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| } | |||
| ] | |||
| expected_result = [LINEAGE_INFO_0] | |||
| result = self.single_querier.get_summary_lineage( | |||
| summary_dir='/path/to/summary0' | |||
| ) | |||
| @@ -216,20 +290,8 @@ class TestQuerier(TestCase): | |||
| def test_get_summary_lineage_success_4(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [ | |||
| { | |||
| 'summary_dir': '/path/to/summary0', | |||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||
| 'metric': event_data.METRIC_0, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| }, | |||
| { | |||
| 'summary_dir': '/path/to/summary1', | |||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||
| 'metric': event_data.METRIC_1, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| }, | |||
| LINEAGE_INFO_0, | |||
| LINEAGE_INFO_1, | |||
| { | |||
| 'summary_dir': '/path/to/summary2', | |||
| **event_data.EVENT_TRAIN_DICT_2['train_lineage'], | |||
| @@ -274,15 +336,7 @@ class TestQuerier(TestCase): | |||
| def test_get_summary_lineage_success_5(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [ | |||
| { | |||
| 'summary_dir': '/path/to/summary1', | |||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||
| 'metric': event_data.METRIC_1, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| } | |||
| ] | |||
| expected_result = [LINEAGE_INFO_1] | |||
| result = self.multi_querier.get_summary_lineage( | |||
| summary_dir='/path/to/summary1' | |||
| ) | |||
| @@ -341,20 +395,8 @@ class TestQuerier(TestCase): | |||
| } | |||
| expected_result = { | |||
| 'object': [ | |||
| create_filtration_result( | |||
| '/path/to/summary1', | |||
| event_data.EVENT_TRAIN_DICT_1, | |||
| event_data.EVENT_EVAL_DICT_1, | |||
| event_data.METRIC_1, | |||
| event_data.DATASET_DICT_0, | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary2', | |||
| event_data.EVENT_TRAIN_DICT_2, | |||
| event_data.EVENT_EVAL_DICT_2, | |||
| event_data.METRIC_2, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_1, | |||
| LINEAGE_FILTRATION_2 | |||
| ], | |||
| 'count': 2, | |||
| } | |||
| @@ -377,20 +419,8 @@ class TestQuerier(TestCase): | |||
| } | |||
| expected_result = { | |||
| 'object': [ | |||
| create_filtration_result( | |||
| '/path/to/summary2', | |||
| event_data.EVENT_TRAIN_DICT_2, | |||
| event_data.EVENT_EVAL_DICT_2, | |||
| event_data.METRIC_2, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary3', | |||
| event_data.EVENT_TRAIN_DICT_3, | |||
| event_data.EVENT_EVAL_DICT_3, | |||
| event_data.METRIC_3, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_2, | |||
| LINEAGE_FILTRATION_3 | |||
| ], | |||
| 'count': 2, | |||
| } | |||
| @@ -405,20 +435,8 @@ class TestQuerier(TestCase): | |||
| } | |||
| expected_result = { | |||
| 'object': [ | |||
| create_filtration_result( | |||
| '/path/to/summary2', | |||
| event_data.EVENT_TRAIN_DICT_2, | |||
| event_data.EVENT_EVAL_DICT_2, | |||
| event_data.METRIC_2, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary3', | |||
| event_data.EVENT_TRAIN_DICT_3, | |||
| event_data.EVENT_EVAL_DICT_3, | |||
| event_data.METRIC_3, | |||
| event_data.DATASET_DICT_0 | |||
| ) | |||
| LINEAGE_FILTRATION_2, | |||
| LINEAGE_FILTRATION_3 | |||
| ], | |||
| 'count': 7, | |||
| } | |||
| @@ -429,82 +447,13 @@ class TestQuerier(TestCase): | |||
| """Test the success of filter_summary_lineage.""" | |||
| expected_result = { | |||
| 'object': [ | |||
| create_filtration_result( | |||
| '/path/to/summary0', | |||
| event_data.EVENT_TRAIN_DICT_0, | |||
| event_data.EVENT_EVAL_DICT_0, | |||
| event_data.METRIC_0, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary1', | |||
| event_data.EVENT_TRAIN_DICT_1, | |||
| event_data.EVENT_EVAL_DICT_1, | |||
| event_data.METRIC_1, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary2', | |||
| event_data.EVENT_TRAIN_DICT_2, | |||
| event_data.EVENT_EVAL_DICT_2, | |||
| event_data.METRIC_2, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary3', | |||
| event_data.EVENT_TRAIN_DICT_3, | |||
| event_data.EVENT_EVAL_DICT_3, | |||
| event_data.METRIC_3, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary4', | |||
| event_data.EVENT_TRAIN_DICT_4, | |||
| event_data.EVENT_EVAL_DICT_4, | |||
| event_data.METRIC_4, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| { | |||
| "summary_dir": '/path/to/summary5', | |||
| "loss_function": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||
| "test_dataset_path": None, | |||
| "test_dataset_count": None, | |||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||
| "learning_rate": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||
| "metric": {}, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| }, | |||
| { | |||
| "summary_dir": '/path/to/summary6', | |||
| "loss_function": None, | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": None, | |||
| "test_dataset_path": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||
| "test_dataset_count": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||
| "network": None, | |||
| "optimizer": None, | |||
| "learning_rate": None, | |||
| "epoch": None, | |||
| "batch_size": None, | |||
| "loss": None, | |||
| "model_size": None, | |||
| "metric": event_data.METRIC_5, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| } | |||
| LINEAGE_FILTRATION_0, | |||
| LINEAGE_FILTRATION_1, | |||
| LINEAGE_FILTRATION_2, | |||
| LINEAGE_FILTRATION_3, | |||
| LINEAGE_FILTRATION_4, | |||
| LINEAGE_FILTRATION_5, | |||
| LINEAGE_FILTRATION_6 | |||
| ], | |||
| 'count': 7, | |||
| } | |||
| @@ -519,15 +468,7 @@ class TestQuerier(TestCase): | |||
| } | |||
| } | |||
| expected_result = { | |||
| 'object': [ | |||
| create_filtration_result( | |||
| '/path/to/summary4', | |||
| event_data.EVENT_TRAIN_DICT_4, | |||
| event_data.EVENT_EVAL_DICT_4, | |||
| event_data.METRIC_4, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| ], | |||
| 'object': [LINEAGE_FILTRATION_4], | |||
| 'count': 1, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| @@ -541,82 +482,13 @@ class TestQuerier(TestCase): | |||
| } | |||
| expected_result = { | |||
| 'object': [ | |||
| create_filtration_result( | |||
| '/path/to/summary0', | |||
| event_data.EVENT_TRAIN_DICT_0, | |||
| event_data.EVENT_EVAL_DICT_0, | |||
| event_data.METRIC_0, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| { | |||
| "summary_dir": '/path/to/summary5', | |||
| "loss_function": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||
| "test_dataset_path": None, | |||
| "test_dataset_count": None, | |||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||
| "learning_rate": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||
| "metric": {}, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| }, | |||
| create_filtration_result( | |||
| '/path/to/summary1', | |||
| event_data.EVENT_TRAIN_DICT_1, | |||
| event_data.EVENT_EVAL_DICT_1, | |||
| event_data.METRIC_1, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary2', | |||
| event_data.EVENT_TRAIN_DICT_2, | |||
| event_data.EVENT_EVAL_DICT_2, | |||
| event_data.METRIC_2, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary3', | |||
| event_data.EVENT_TRAIN_DICT_3, | |||
| event_data.EVENT_EVAL_DICT_3, | |||
| event_data.METRIC_3, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary4', | |||
| event_data.EVENT_TRAIN_DICT_4, | |||
| event_data.EVENT_EVAL_DICT_4, | |||
| event_data.METRIC_4, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| { | |||
| "summary_dir": '/path/to/summary6', | |||
| "loss_function": None, | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": None, | |||
| "test_dataset_path": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||
| "test_dataset_count": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||
| "network": None, | |||
| "optimizer": None, | |||
| "learning_rate": None, | |||
| "epoch": None, | |||
| "batch_size": None, | |||
| "loss": None, | |||
| "model_size": None, | |||
| "metric": event_data.METRIC_5, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| } | |||
| LINEAGE_FILTRATION_0, | |||
| LINEAGE_FILTRATION_5, | |||
| LINEAGE_FILTRATION_1, | |||
| LINEAGE_FILTRATION_2, | |||
| LINEAGE_FILTRATION_3, | |||
| LINEAGE_FILTRATION_4, | |||
| LINEAGE_FILTRATION_6 | |||
| ], | |||
| 'count': 7, | |||
| } | |||
| @@ -631,82 +503,13 @@ class TestQuerier(TestCase): | |||
| } | |||
| expected_result = { | |||
| 'object': [ | |||
| { | |||
| "summary_dir": '/path/to/summary6', | |||
| "loss_function": None, | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": None, | |||
| "test_dataset_path": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||
| "test_dataset_count": | |||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||
| "network": None, | |||
| "optimizer": None, | |||
| "learning_rate": None, | |||
| "epoch": None, | |||
| "batch_size": None, | |||
| "loss": None, | |||
| "model_size": None, | |||
| "metric": event_data.METRIC_5, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| }, | |||
| create_filtration_result( | |||
| '/path/to/summary4', | |||
| event_data.EVENT_TRAIN_DICT_4, | |||
| event_data.EVENT_EVAL_DICT_4, | |||
| event_data.METRIC_4, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary3', | |||
| event_data.EVENT_TRAIN_DICT_3, | |||
| event_data.EVENT_EVAL_DICT_3, | |||
| event_data.METRIC_3, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary2', | |||
| event_data.EVENT_TRAIN_DICT_2, | |||
| event_data.EVENT_EVAL_DICT_2, | |||
| event_data.METRIC_2, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary1', | |||
| event_data.EVENT_TRAIN_DICT_1, | |||
| event_data.EVENT_EVAL_DICT_1, | |||
| event_data.METRIC_1, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| create_filtration_result( | |||
| '/path/to/summary0', | |||
| event_data.EVENT_TRAIN_DICT_0, | |||
| event_data.EVENT_EVAL_DICT_0, | |||
| event_data.METRIC_0, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| { | |||
| "summary_dir": '/path/to/summary5', | |||
| "loss_function": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||
| "train_dataset_path": None, | |||
| "train_dataset_count": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||
| "test_dataset_path": None, | |||
| "test_dataset_count": None, | |||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||
| "learning_rate": | |||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||
| "metric": {}, | |||
| "dataset_graph": event_data.DATASET_DICT_0, | |||
| "dataset_mark": '2' | |||
| } | |||
| LINEAGE_FILTRATION_6, | |||
| LINEAGE_FILTRATION_4, | |||
| LINEAGE_FILTRATION_3, | |||
| LINEAGE_FILTRATION_2, | |||
| LINEAGE_FILTRATION_1, | |||
| LINEAGE_FILTRATION_0, | |||
| LINEAGE_FILTRATION_5 | |||
| ], | |||
| 'count': 7, | |||
| } | |||
| @@ -722,15 +525,7 @@ class TestQuerier(TestCase): | |||
| } | |||
| } | |||
| expected_result = { | |||
| 'object': [ | |||
| create_filtration_result( | |||
| '/path/to/summary4', | |||
| event_data.EVENT_TRAIN_DICT_4, | |||
| event_data.EVENT_EVAL_DICT_4, | |||
| event_data.METRIC_4, | |||
| event_data.DATASET_DICT_0 | |||
| ), | |||
| ], | |||
| 'object': [LINEAGE_FILTRATION_4], | |||
| 'count': 1, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| @@ -809,20 +604,8 @@ class TestQuerier(TestCase): | |||
| querier = Querier(summary_path) | |||
| querier._parse_failed_paths.append('/path/to/summary1/log1') | |||
| expected_result = [ | |||
| { | |||
| 'summary_dir': '/path/to/summary0', | |||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||
| 'metric': event_data.METRIC_0, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| }, | |||
| { | |||
| 'summary_dir': '/path/to/summary1', | |||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||
| 'metric': event_data.METRIC_1, | |||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| } | |||
| LINEAGE_INFO_0, | |||
| LINEAGE_INFO_1 | |||
| ] | |||
| result = querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| @@ -842,17 +625,7 @@ class TestQuerier(TestCase): | |||
| querier._parse_failed_paths.append('/path/to/summary1/log1') | |||
| args[0].return_value = create_lineage_info(None, None, None) | |||
| expected_result = [ | |||
| { | |||
| 'summary_dir': '/path/to/summary0', | |||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||
| 'metric': event_data.METRIC_0, | |||
| 'valid_dataset': | |||
| event_data.EVENT_EVAL_DICT_0['evaluation_lineage'][ | |||
| 'valid_dataset'], | |||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||
| } | |||
| ] | |||
| expected_result = [LINEAGE_INFO_0] | |||
| result = querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| self.assertListEqual( | |||
| @@ -15,11 +15,12 @@ | |||
| """Test the query_model module.""" | |||
| from unittest import TestCase | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||
| LineageEventNotExistException, LineageEventFieldNotExistException | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException, | |||
| LineageEventNotExistException) | |||
| from mindinsight.lineagemgr.querier.query_model import LineageObj | |||
| from . import event_data | |||
| from .test_querier import create_lineage_info, create_filtration_result | |||
| from .test_querier import create_filtration_result, create_lineage_info | |||
| class TestLineageObj(TestCase): | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -19,10 +19,10 @@ import time | |||
| from google.protobuf import json_format | |||
| from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| from .log_generator import LogGenerator | |||
| class GraphLogGenerator(LogGenerator): | |||
| """ | |||
| @@ -74,7 +74,7 @@ class GraphLogGenerator(LogGenerator): | |||
| if __name__ == "__main__": | |||
| graph_log_generator = GraphLogGenerator() | |||
| test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time())) | |||
| graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json") | |||
| graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators--", "graph_base.json") | |||
| with open(graph_base_path, 'r') as load_f: | |||
| graph = json.load(load_f) | |||
| graph_log_generator.generate_log(test_file_name, graph) | |||
| @@ -18,10 +18,11 @@ import time | |||
| import numpy as np | |||
| from PIL import Image | |||
| from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| from .log_generator import LogGenerator | |||
| class ImagesLogGenerator(LogGenerator): | |||
| """ | |||
| @@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator): | |||
| images_metadata.append(image_metadata) | |||
| images_values.update({step: image_tensor}) | |||
| values = dict( | |||
| wall_time=wall_time, | |||
| step=step, | |||
| image=image_tensor, | |||
| tag=tag_name | |||
| ) | |||
| values = dict(wall_time=wall_time, step=step, image=image_tensor, tag=tag_name) | |||
| self._write_log_one_step(file_path, values) | |||
| @@ -17,7 +17,7 @@ | |||
| import struct | |||
| from abc import abstractmethod | |||
| from tests.st.func.datavisual.utils import crc32 | |||
| from ...utils import crc32 | |||
| class LogGenerator: | |||
| @@ -16,10 +16,11 @@ | |||
| import time | |||
| import numpy as np | |||
| from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| from .log_generator import LogGenerator | |||
| class ScalarsLogGenerator(LogGenerator): | |||
| """ | |||
| @@ -19,13 +19,12 @@ import json | |||
| import os | |||
| import time | |||
| from tests.st.func.datavisual.constants import SUMMARY_PREFIX | |||
| from tests.st.func.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator | |||
| from tests.st.func.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator | |||
| from tests.st.func.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .log_generators.graph_log_generator import GraphLogGenerator | |||
| from .log_generators.images_log_generator import ImagesLogGenerator | |||
| from .log_generators.scalars_log_generator import ScalarsLogGenerator | |||
| log_generators = { | |||
| PluginNameEnum.GRAPH.value: GraphLogGenerator(), | |||
| PluginNameEnum.IMAGE.value: ImagesLogGenerator(), | |||
| @@ -35,10 +34,12 @@ log_generators = { | |||
| class LogOperations: | |||
| """Log Operations.""" | |||
| def __init__(self): | |||
| self._step_num = 3 | |||
| self._tag_num = 2 | |||
| self._time_count = 0 | |||
| self._graph_base_path = os.path.join(os.path.dirname(__file__), "log_generators", "graph_base.json") | |||
| def _get_steps(self): | |||
| """Get steps.""" | |||
| @@ -61,9 +62,7 @@ class LogOperations: | |||
| metadata_dict["plugins"].update({plugin_name: list()}) | |||
| log_generator = log_generators.get(plugin_name) | |||
| if plugin_name == PluginNameEnum.GRAPH.value: | |||
| graph_base_path = os.path.join(os.path.dirname(__file__), | |||
| os.pardir, "utils", "log_generators", "graph_base.json") | |||
| with open(graph_base_path, 'r') as load_f: | |||
| with open(self._graph_base_path, 'r') as load_f: | |||
| graph_dict = json.load(load_f) | |||
| values = log_generator.generate_log(file_path, graph_dict) | |||
| metadata_dict["actual_values"].update({plugin_name: values}) | |||
| @@ -82,13 +81,13 @@ class LogOperations: | |||
| self._time_count += 1 | |||
| return metadata_dict | |||
| def create_summary_logs(self, summary_base_dir, summary_dir_num, start_index=0): | |||
| def create_summary_logs(self, summary_base_dir, summary_dir_num, dir_prefix, start_index=0): | |||
| """Create summary logs in summary_base_dir.""" | |||
| summary_metadata = dict() | |||
| steps_list = self._get_steps() | |||
| tag_name_list = self._get_tags() | |||
| for i in range(start_index, summary_dir_num + start_index): | |||
| log_dir = os.path.join(summary_base_dir, f'{SUMMARY_PREFIX}{i}') | |||
| log_dir = os.path.join(summary_base_dir, f'{dir_prefix}{i}') | |||
| os.makedirs(log_dir) | |||
| train_id = log_dir.replace(summary_base_dir, ".") | |||
| @@ -120,3 +119,47 @@ class LogOperations: | |||
| metadata_dict = self.create_summary(log_dir, steps_list, tag_name_list) | |||
| return {train_id: metadata_dict} | |||
| def generate_log(self, plugin_name, log_dir, log_settings=None, valid=True): | |||
| """ | |||
| Generate log for ut. | |||
| Args: | |||
| plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'. | |||
| log_dir (str): Log path to write log. | |||
| log_settings (dict): Info about the log, e.g.: | |||
| { | |||
| current_time (int): Timestamp in summary file name, not necessary. | |||
| graph_base_path (str): Path of graph_bas.json, necessary for `graph`. | |||
| steps (list[int]): Steps for `image` and `scalar`, default is [1]. | |||
| tag (str): Tag name, default is 'default_tag'. | |||
| } | |||
| valid (bool): If true, summary name will be valid. | |||
| Returns: | |||
| str, Summary log path. | |||
| """ | |||
| if log_settings is None: | |||
| log_settings = dict() | |||
| current_time = log_settings.get('time', int(time.time())) | |||
| current_time = int(current_time) | |||
| log_generator = log_generators.get(plugin_name) | |||
| if valid: | |||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time))) | |||
| else: | |||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time))) | |||
| if plugin_name == PluginNameEnum.GRAPH.value: | |||
| with open(self._graph_base_path, 'r') as load_f: | |||
| graph_dict = json.load(load_f) | |||
| graph_dict = log_generator.generate_log(temp_path, graph_dict) | |||
| return temp_path, graph_dict | |||
| steps_list = log_settings.get('steps', [1]) | |||
| tag_name = log_settings.get('tag', 'default_tag') | |||
| metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name) | |||
| return temp_path, metadata, values | |||