| @@ -0,0 +1,26 @@ | |||||
| <!-- Thanks for sending a pull request! Here are some tips for you: | |||||
| If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md | |||||
| --> | |||||
| **What type of PR is this?** | |||||
| > Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||||
| > | |||||
| > /kind bug | |||||
| > /kind task | |||||
| > /kind feature | |||||
| **What does this PR do / why do we need it**: | |||||
| **Which issue(s) this PR fixes**: | |||||
| <!-- | |||||
| *Automatically closes linked issue when PR is merged. | |||||
| Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. | |||||
| --> | |||||
| Fixes # | |||||
| **Special notes for your reviewers**: | |||||
| @@ -0,0 +1,19 @@ | |||||
| --- | |||||
| name: RFC | |||||
| about: Use this template for the new feature or enhancement | |||||
| labels: kind/feature or kind/enhancement | |||||
| --- | |||||
| ## Background | |||||
| - Describe the status of the problem you wish to solve | |||||
| - Attach the relevant issue if have | |||||
| ## Introduction | |||||
| - Describe the general solution, design and/or pseudo-code | |||||
| ## Trail | |||||
| | No. | Task Description | Related Issue(URL) | | |||||
| | --- | ---------------- | ------------------ | | |||||
| | 1 | | | | |||||
| | 2 | | | | |||||
| @@ -0,0 +1,43 @@ | |||||
| --- | |||||
| name: Bug Report | |||||
| about: Use this template for reporting a bug | |||||
| labels: kind/bug | |||||
| --- | |||||
| <!-- Thanks for sending an issue! Here are some tips for you: | |||||
| If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md | |||||
| --> | |||||
| ## Environment | |||||
| ### Hardware Environment(`Ascend`/`GPU`/`CPU`): | |||||
| > Uncomment only one ` /device <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||||
| > | |||||
| > `/device ascend`</br> | |||||
| > `/device gpu`</br> | |||||
| > `/device cpu`</br> | |||||
| ### Software Environment: | |||||
| - **MindSpore version (source or binary)**: | |||||
| - **Python version (e.g., Python 3.7.5)**: | |||||
| - **OS platform and distribution (e.g., Linux Ubuntu 16.04)**: | |||||
| - **GCC/Compiler version (if compiled from source)**: | |||||
| ## Describe the current behavior | |||||
| ## Describe the expected behavior | |||||
| ## Steps to reproduce the issue | |||||
| 1. | |||||
| 2. | |||||
| 3. | |||||
| ## Related log / screenshot | |||||
| ## Special notes for this issue | |||||
| @@ -0,0 +1,19 @@ | |||||
| --- | |||||
| name: Task | |||||
| about: Use this template for task tracking | |||||
| labels: kind/task | |||||
| --- | |||||
| ## Task Description | |||||
| ## Task Goal | |||||
| ## Sub Task | |||||
| | No. | Task Description | Issue ID | | |||||
| | --- | ---------------- | -------- | | |||||
| | 1 | | | | |||||
| | 2 | | | | |||||
| @@ -0,0 +1,24 @@ | |||||
| <!-- Thanks for sending a pull request! Here are some tips for you: | |||||
| If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md | |||||
| --> | |||||
| **What type of PR is this?** | |||||
| > Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line: | |||||
| > | |||||
| > `/kind bug`</br> | |||||
| > `/kind task`</br> | |||||
| > `/kind feature`</br> | |||||
| **What does this PR do / why do we need it**: | |||||
| **Which issue(s) this PR fixes**: | |||||
| <!-- | |||||
| *Automatically closes linked issue when PR is merged. | |||||
| Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. | |||||
| --> | |||||
| Fixes # | |||||
| **Special notes for your reviewers**: | |||||
| @@ -18,19 +18,21 @@ This file is used to define the basic graph. | |||||
| import copy | import copy | ||||
| import time | import time | ||||
| from enum import Enum | |||||
| from mindinsight.datavisual.common.log import logger | from mindinsight.datavisual.common.log import logger | ||||
| from mindinsight.datavisual.common import exceptions | from mindinsight.datavisual.common import exceptions | ||||
| from .node import NodeTypeEnum | from .node import NodeTypeEnum | ||||
| from .node import Node | from .node import Node | ||||
| class EdgeTypeEnum: | |||||
| class EdgeTypeEnum(Enum): | |||||
| """Node edge type enum.""" | """Node edge type enum.""" | ||||
| control = 'control' | |||||
| data = 'data' | |||||
| CONTROL = 'control' | |||||
| DATA = 'data' | |||||
| class DataTypeEnum: | |||||
| class DataTypeEnum(Enum): | |||||
| """Data type enum.""" | """Data type enum.""" | ||||
| DT_TENSOR = 13 | DT_TENSOR = 13 | ||||
| @@ -292,70 +294,65 @@ class Graph: | |||||
| output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value | output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value | ||||
| node.update_output({dst_name: output_attr}) | node.update_output({dst_name: output_attr}) | ||||
| def _calc_polymeric_input_output(self): | |||||
| def _update_polymeric_input_output(self): | |||||
| """Calc polymeric input and output after build polymeric node.""" | """Calc polymeric input and output after build polymeric node.""" | ||||
| for name, node in self._normal_nodes.items(): | |||||
| polymeric_input = {} | |||||
| for src_name in node.input: | |||||
| src_node = self._polymeric_nodes.get(src_name) | |||||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||||
| src_name = src_name if not src_node else src_node.polymeric_scope_name | |||||
| output_name = self._calc_dummy_node_name(name, src_name) | |||||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||||
| continue | |||||
| if not src_node: | |||||
| continue | |||||
| if not node.name_scope and src_node.name_scope: | |||||
| # if current node is in first layer, and the src node is not in | |||||
| # the first layer, the src node will not be the polymeric input of current node. | |||||
| continue | |||||
| if node.name_scope == src_node.name_scope \ | |||||
| or node.name_scope.startswith(src_node.name_scope): | |||||
| polymeric_input.update( | |||||
| {src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) | |||||
| for node in self._normal_nodes.values(): | |||||
| polymeric_input = self._calc_polymeric_attr(node, 'input') | |||||
| node.update_polymeric_input(polymeric_input) | node.update_polymeric_input(polymeric_input) | ||||
| polymeric_output = {} | |||||
| for dst_name in node.output: | |||||
| dst_node = self._polymeric_nodes.get(dst_name) | |||||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||||
| dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name | |||||
| output_name = self._calc_dummy_node_name(name, dst_name) | |||||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||||
| continue | |||||
| if not dst_node: | |||||
| continue | |||||
| if not node.name_scope and dst_node.name_scope: | |||||
| continue | |||||
| if node.name_scope == dst_node.name_scope \ | |||||
| or node.name_scope.startswith(dst_node.name_scope): | |||||
| polymeric_output.update( | |||||
| {dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) | |||||
| polymeric_output = self._calc_polymeric_attr(node, 'output') | |||||
| node.update_polymeric_output(polymeric_output) | node.update_polymeric_output(polymeric_output) | ||||
| for name, node in self._polymeric_nodes.items(): | for name, node in self._polymeric_nodes.items(): | ||||
| polymeric_input = {} | polymeric_input = {} | ||||
| for src_name in node.input: | for src_name in node.input: | ||||
| output_name = self._calc_dummy_node_name(name, src_name) | output_name = self._calc_dummy_node_name(name, src_name) | ||||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||||
| node.update_polymeric_input(polymeric_input) | node.update_polymeric_input(polymeric_input) | ||||
| polymeric_output = {} | polymeric_output = {} | ||||
| for dst_name in node.output: | for dst_name in node.output: | ||||
| polymeric_output = {} | polymeric_output = {} | ||||
| output_name = self._calc_dummy_node_name(name, dst_name) | output_name = self._calc_dummy_node_name(name, dst_name) | ||||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||||
| node.update_polymeric_output(polymeric_output) | node.update_polymeric_output(polymeric_output) | ||||
| def _calc_polymeric_attr(self, node, attr): | |||||
| """ | |||||
| Calc polymeric input or polymeric output after build polymeric node. | |||||
| Args: | |||||
| node (Node): Computes the polymeric input for a given node. | |||||
| attr (str): The polymeric attr, optional value is `input` or `output`. | |||||
| Returns: | |||||
| dict, return polymeric input or polymeric output of the given node. | |||||
| """ | |||||
| polymeric_attr = {} | |||||
| for node_name in getattr(node, attr): | |||||
| polymeric_node = self._polymeric_nodes.get(node_name) | |||||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||||
| node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name | |||||
| dummy_node_name = self._calc_dummy_node_name(node.name, node_name) | |||||
| polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||||
| continue | |||||
| if not polymeric_node: | |||||
| continue | |||||
| if not node.name_scope and polymeric_node.name_scope: | |||||
| # If current node is in top-level layer, and the polymeric_node node is not in | |||||
| # the top-level layer, the polymeric node will not be the polymeric input | |||||
| # or polymeric output of current node. | |||||
| continue | |||||
| if node.name_scope == polymeric_node.name_scope \ | |||||
| or node.name_scope.startswith(polymeric_node.name_scope + '/'): | |||||
| polymeric_attr.update( | |||||
| {polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||||
| return polymeric_attr | |||||
| def _calc_dummy_node_name(self, current_node_name, other_node_name): | def _calc_dummy_node_name(self, current_node_name, other_node_name): | ||||
| """ | """ | ||||
| Calc dummy node name. | Calc dummy node name. | ||||
| @@ -39,7 +39,7 @@ class MSGraph(Graph): | |||||
| self._build_leaf_nodes(graph_proto) | self._build_leaf_nodes(graph_proto) | ||||
| self._build_polymeric_nodes() | self._build_polymeric_nodes() | ||||
| self._build_name_scope_nodes() | self._build_name_scope_nodes() | ||||
| self._calc_polymeric_input_output() | |||||
| self._update_polymeric_input_output() | |||||
| logger.info("Build graph end, normal node count: %s, polymeric node " | logger.info("Build graph end, normal node count: %s, polymeric node " | ||||
| "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes)) | "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes)) | ||||
| @@ -90,9 +90,9 @@ class MSGraph(Graph): | |||||
| node_name = leaf_node_id_map_name[node_def.name] | node_name = leaf_node_id_map_name[node_def.name] | ||||
| node = self._leaf_nodes[node_name] | node = self._leaf_nodes[node_name] | ||||
| for input_def in node_def.input: | for input_def in node_def.input: | ||||
| edge_type = EdgeTypeEnum.data | |||||
| edge_type = EdgeTypeEnum.DATA.value | |||||
| if input_def.type == "CONTROL_EDGE": | if input_def.type == "CONTROL_EDGE": | ||||
| edge_type = EdgeTypeEnum.control | |||||
| edge_type = EdgeTypeEnum.CONTROL.value | |||||
| if const_nodes_map.get(input_def.name): | if const_nodes_map.get(input_def.name): | ||||
| const_node = copy.deepcopy(const_nodes_map[input_def.name]) | const_node = copy.deepcopy(const_nodes_map[input_def.name]) | ||||
| @@ -218,7 +218,7 @@ class MSGraph(Graph): | |||||
| node = Node(name=const.key, node_id=const_node_id) | node = Node(name=const.key, node_id=const_node_id) | ||||
| node.node_type = NodeTypeEnum.CONST.value | node.node_type = NodeTypeEnum.CONST.value | ||||
| node.update_attr({const.key: str(const.value)}) | node.update_attr({const.key: str(const.value)}) | ||||
| if const.value.dtype == DataTypeEnum.DT_TENSOR: | |||||
| if const.value.dtype == DataTypeEnum.DT_TENSOR.value: | |||||
| shape = [] | shape = [] | ||||
| for dim in const.value.tensor_val.dims: | for dim in const.value.tensor_val.dims: | ||||
| shape.append(dim) | shape.append(dim) | ||||
| @@ -172,7 +172,7 @@ class Node: | |||||
| Args: | Args: | ||||
| polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name: | polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name: | ||||
| {'edge_type': EdgeTypeEnum.data}}). | |||||
| {'edge_type': EdgeTypeEnum.DATA.value}}). | |||||
| """ | """ | ||||
| self._polymeric_output.update(polymeric_output) | self._polymeric_output.update(polymeric_output) | ||||
| @@ -168,7 +168,7 @@ class TrainLineage(Callback): | |||||
| train_lineage = AnalyzeObject.get_network_args( | train_lineage = AnalyzeObject.get_network_args( | ||||
| run_context_args, train_lineage | run_context_args, train_lineage | ||||
| ) | ) | ||||
| train_dataset = run_context_args.get('train_dataset') | train_dataset = run_context_args.get('train_dataset') | ||||
| callbacks = run_context_args.get('list_callback') | callbacks = run_context_args.get('list_callback') | ||||
| list_callback = getattr(callbacks, '_callbacks', []) | list_callback = getattr(callbacks, '_callbacks', []) | ||||
| @@ -601,7 +601,7 @@ class AnalyzeObject: | |||||
| loss = None | loss = None | ||||
| else: | else: | ||||
| loss = run_context_args.get('net_outputs') | loss = run_context_args.get('net_outputs') | ||||
| if loss: | if loss: | ||||
| log.info('Calculating loss...') | log.info('Calculating loss...') | ||||
| loss_numpy = loss.asnumpy() | loss_numpy = loss.asnumpy() | ||||
| @@ -610,7 +610,7 @@ class AnalyzeObject: | |||||
| train_lineage[Metadata.loss] = loss | train_lineage[Metadata.loss] = loss | ||||
| else: | else: | ||||
| train_lineage[Metadata.loss] = None | train_lineage[Metadata.loss] = None | ||||
| # Analyze classname of optimizer, loss function and training network. | # Analyze classname of optimizer, loss function and training network. | ||||
| train_lineage[Metadata.optimizer] = type(optimizer).__name__ \ | train_lineage[Metadata.optimizer] = type(optimizer).__name__ \ | ||||
| if optimizer else None | if optimizer else None | ||||
| @@ -18,13 +18,10 @@ Description: This file is used for some common util. | |||||
| import os | import os | ||||
| import shutil | import shutil | ||||
| from unittest.mock import Mock | from unittest.mock import Mock | ||||
| import pytest | import pytest | ||||
| from flask import Response | from flask import Response | ||||
| from tests.st.func.datavisual import constants | |||||
| from tests.st.func.datavisual.utils.log_operations import LogOperations | |||||
| from tests.st.func.datavisual.utils.utils import check_loading_done | |||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| from mindinsight.datavisual.data_transform import data_manager | from mindinsight.datavisual.data_transform import data_manager | ||||
| from mindinsight.datavisual.data_transform.data_manager import DataManager | from mindinsight.datavisual.data_transform.data_manager import DataManager | ||||
| @@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat | |||||
| from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE | from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE | ||||
| from mindinsight.datavisual.utils import tools | from mindinsight.datavisual.utils import tools | ||||
| from ....utils.log_operations import LogOperations | |||||
| from ....utils.tools import check_loading_done | |||||
| from . import constants | |||||
| from . import globals as gbl | |||||
| summaries_metadata = None | summaries_metadata = None | ||||
| mock_data_manager = None | mock_data_manager = None | ||||
| summary_base_dir = constants.SUMMARY_BASE_DIR | summary_base_dir = constants.SUMMARY_BASE_DIR | ||||
| @@ -55,17 +57,21 @@ def init_summary_logs(): | |||||
| os.mkdir(summary_base_dir, mode=mode) | os.mkdir(summary_base_dir, mode=mode) | ||||
| global summaries_metadata, mock_data_manager | global summaries_metadata, mock_data_manager | ||||
| log_operations = LogOperations() | log_operations = LogOperations() | ||||
| summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST) | |||||
| summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, | |||||
| constants.SUMMARY_DIR_PREFIX) | |||||
| mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)]) | mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)]) | ||||
| mock_data_manager.start_load_data(reload_interval=0) | mock_data_manager.start_load_data(reload_interval=0) | ||||
| check_loading_done(mock_data_manager) | check_loading_done(mock_data_manager) | ||||
| summaries_metadata.update(log_operations.create_summary_logs( | |||||
| summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, constants.SUMMARY_DIR_NUM_FIRST)) | |||||
| summaries_metadata.update(log_operations.create_multiple_logs( | |||||
| summary_base_dir, constants.MULTIPLE_DIR_NAME, constants.MULTIPLE_LOG_NUM)) | |||||
| summaries_metadata.update(log_operations.create_reservoir_log( | |||||
| summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM)) | |||||
| summaries_metadata.update( | |||||
| log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, | |||||
| constants.SUMMARY_DIR_NUM_FIRST)) | |||||
| summaries_metadata.update( | |||||
| log_operations.create_multiple_logs(summary_base_dir, constants.MULTIPLE_DIR_NAME, | |||||
| constants.MULTIPLE_LOG_NUM)) | |||||
| summaries_metadata.update( | |||||
| log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME, | |||||
| constants.RESERVOIR_STEP_NUM)) | |||||
| mock_data_manager.start_load_data(reload_interval=0) | mock_data_manager.start_load_data(reload_interval=0) | ||||
| # Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING. | # Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING. | ||||
| @@ -73,7 +79,7 @@ def init_summary_logs(): | |||||
| # Maximum number of loads is `MAX_DATA_LOADER_SIZE`. | # Maximum number of loads is `MAX_DATA_LOADER_SIZE`. | ||||
| for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE): | for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE): | ||||
| summaries_metadata.pop("./%s%d" % (constants.SUMMARY_PREFIX, i)) | |||||
| summaries_metadata.pop("./%s%d" % (constants.SUMMARY_DIR_PREFIX, i)) | |||||
| yield | yield | ||||
| finally: | finally: | ||||
| @@ -16,7 +16,7 @@ | |||||
| import tempfile | import tempfile | ||||
| SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name | SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name | ||||
| SUMMARY_PREFIX = "summary" | |||||
| SUMMARY_DIR_PREFIX = "summary" | |||||
| SUMMARY_DIR_NUM_FIRST = 5 | SUMMARY_DIR_NUM_FIRST = 5 | ||||
| SUMMARY_DIR_NUM_SECOND = 11 | SUMMARY_DIR_NUM_SECOND = 11 | ||||
| @@ -19,11 +19,11 @@ Usage: | |||||
| pytest tests/st/func/datavisual | pytest tests/st/func/datavisual | ||||
| """ | """ | ||||
| import os | import os | ||||
| import json | |||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from .. import globals as gbl | |||||
| from .....utils.tools import get_url, compare_result_with_file | |||||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' | BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' | ||||
| @@ -33,12 +33,6 @@ class TestQueryNodes: | |||||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | ||||
| def compare_result_with_file(self, result, filename): | |||||
| """Compare result with file which contain the expected results.""" | |||||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||||
| expected_results = json.load(fp) | |||||
| assert result == expected_results | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @@ -65,4 +59,5 @@ class TestQueryNodes: | |||||
| url = get_url(BASE_URL, params) | url = get_url(BASE_URL, params) | ||||
| response = client.get(url) | response = client.get(url) | ||||
| assert response.status_code == 200 | assert response.status_code == 200 | ||||
| self.compare_result_with_file(response.get_json(), result_file) | |||||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||||
| compare_result_with_file(response.get_json(), file_path) | |||||
| @@ -19,12 +19,11 @@ Usage: | |||||
| pytest tests/st/func/datavisual | pytest tests/st/func/datavisual | ||||
| """ | """ | ||||
| import os | import os | ||||
| import json | |||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from .. import globals as gbl | |||||
| from .....utils.tools import get_url, compare_result_with_file | |||||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' | BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' | ||||
| @@ -34,12 +33,6 @@ class TestQuerySingleNode: | |||||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | ||||
| def compare_result_with_file(self, result, filename): | |||||
| """Compare result with file which contain the expected results.""" | |||||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||||
| expected_results = json.load(fp) | |||||
| assert result == expected_results | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @@ -59,4 +52,5 @@ class TestQuerySingleNode: | |||||
| url = get_url(BASE_URL, params) | url = get_url(BASE_URL, params) | ||||
| response = client.get(url) | response = client.get(url) | ||||
| assert response.status_code == 200 | assert response.status_code == 200 | ||||
| self.compare_result_with_file(response.get_json(), result_file) | |||||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||||
| compare_result_with_file(response.get_json(), file_path) | |||||
| @@ -19,25 +19,20 @@ Usage: | |||||
| pytest tests/st/func/datavisual | pytest tests/st/func/datavisual | ||||
| """ | """ | ||||
| import os | import os | ||||
| import json | |||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from .. import globals as gbl | |||||
| from .....utils.tools import get_url, compare_result_with_file | |||||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' | BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' | ||||
| class TestSearchNodes: | class TestSearchNodes: | ||||
| """Test search nodes restful APIs.""" | |||||
| """Test searching nodes restful APIs.""" | |||||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | ||||
| def compare_result_with_file(self, result, filename): | |||||
| """Compare result with file which contain the expected results.""" | |||||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||||
| expected_results = json.load(fp) | |||||
| assert result == expected_results | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @@ -58,4 +53,5 @@ class TestSearchNodes: | |||||
| url = get_url(BASE_URL, params) | url = get_url(BASE_URL, params) | ||||
| response = client.get(url) | response = client.get(url) | ||||
| assert response.status_code == 200 | assert response.status_code == 200 | ||||
| self.compare_result_with_file(response.get_json(), result_file) | |||||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||||
| compare_result_with_file(response.get_json(), file_path) | |||||
| @@ -20,13 +20,13 @@ Usage: | |||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID | |||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .....utils.tools import get_url | |||||
| from .. import globals as gbl | |||||
| from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID | |||||
| BASE_URL = '/v1/mindinsight/datavisual/image/metadata' | BASE_URL = '/v1/mindinsight/datavisual/image/metadata' | ||||
| @@ -20,11 +20,11 @@ Usage: | |||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .....utils.tools import get_image_tensor_from_bytes, get_url | |||||
| from .. import globals as gbl | |||||
| BASE_URL = '/v1/mindinsight/datavisual/image/single-image' | BASE_URL = '/v1/mindinsight/datavisual/image/single-image' | ||||
| @@ -19,11 +19,12 @@ Usage: | |||||
| pytest tests/st/func/datavisual | pytest tests/st/func/datavisual | ||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .....utils.tools import get_url | |||||
| from .. import globals as gbl | |||||
| BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata' | BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata' | ||||
| @@ -20,11 +20,11 @@ Usage: | |||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .....utils.tools import get_url | |||||
| from .. import globals as gbl | |||||
| BASE_URL = '/v1/mindinsight/datavisual/plugins' | BASE_URL = '/v1/mindinsight/datavisual/plugins' | ||||
| @@ -19,11 +19,12 @@ Usage: | |||||
| pytest tests/st/func/datavisual | pytest tests/st/func/datavisual | ||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .....utils.tools import get_url | |||||
| from .. import globals as gbl | |||||
| BASE_URL = '/v1/mindinsight/datavisual/single-job' | BASE_URL = '/v1/mindinsight/datavisual/single-job' | ||||
| @@ -20,8 +20,8 @@ Usage: | |||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.constants import SUMMARY_DIR_NUM | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from ..constants import SUMMARY_DIR_NUM | |||||
| from .....utils.tools import get_url | |||||
| BASE_URL = '/v1/mindinsight/datavisual/train-jobs' | BASE_URL = '/v1/mindinsight/datavisual/train-jobs' | ||||
| @@ -1,79 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Log generator for graph.""" | |||||
| import json | |||||
| import os | |||||
| import time | |||||
| from google.protobuf import json_format | |||||
| from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator | |||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||||
| class GraphLogGenerator(LogGenerator): | |||||
| """ | |||||
| Log generator for graph. | |||||
| This is a log generator writing graph. User can use it to generate fake | |||||
| summary logs about graph. | |||||
| """ | |||||
| def generate_log(self, file_path, graph_dict): | |||||
| """ | |||||
| Generate log for external calls. | |||||
| Args: | |||||
| file_path (str): Path to write logs. | |||||
| graph_dict (dict): A dict consists of graph node information. | |||||
| Returns: | |||||
| dict, generated scalar metadata. | |||||
| """ | |||||
| graph_event = self.generate_event(dict(graph=graph_dict)) | |||||
| self._write_log_from_event(file_path, graph_event) | |||||
| return graph_dict | |||||
| def generate_event(self, values): | |||||
| """ | |||||
| Method for generating graph event. | |||||
| Args: | |||||
| values (dict): Graph values. e.g. {'graph': graph_dict}. | |||||
| Returns: | |||||
| summary_pb2.Event. | |||||
| """ | |||||
| graph_json = { | |||||
| 'wall_time': time.time(), | |||||
| 'graph_def': values.get('graph'), | |||||
| } | |||||
| graph_event = json_format.Parse(json.dumps(graph_json), summary_pb2.Event()) | |||||
| return graph_event | |||||
| if __name__ == "__main__": | |||||
| graph_log_generator = GraphLogGenerator() | |||||
| test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time())) | |||||
| graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json") | |||||
| with open(graph_base_path, 'r') as load_f: | |||||
| graph = json.load(load_f) | |||||
| graph_log_generator.generate_log(test_file_name, graph) | |||||
| @@ -20,11 +20,11 @@ Usage: | |||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .....utils.tools import get_url | |||||
| from .. import globals as gbl | |||||
| TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' | TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' | ||||
| PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' | PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' | ||||
| METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' | METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' | ||||
| @@ -20,11 +20,11 @@ Usage: | |||||
| """ | """ | ||||
| import pytest | import pytest | ||||
| from tests.st.func.datavisual.utils import globals as gbl | |||||
| from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .....utils.tools import get_image_tensor_from_bytes, get_url | |||||
| from .. import globals as gbl | |||||
| TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' | TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' | ||||
| PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' | PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' | ||||
| METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' | METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' | ||||
| @@ -26,11 +26,101 @@ from unittest import TestCase | |||||
| import pytest | import pytest | ||||
| from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||||
| LineageParamSummaryPathError, LineageParamValueError, LineageParamTypeError, \ | |||||
| LineageSearchConditionParamError, LineageFileNotFoundError | |||||
| from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH | |||||
| from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError, | |||||
| LineageParamTypeError, LineageParamValueError, | |||||
| LineageSearchConditionParamError) | |||||
| from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 | |||||
| LINEAGE_INFO_RUN1 = { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||||
| 'metric': { | |||||
| 'accuracy': 0.78 | |||||
| }, | |||||
| 'hyper_parameters': { | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'epoch': 14, | |||||
| 'parallel_mode': 'stand_alone', | |||||
| 'device_num': 2, | |||||
| 'batch_size': 32 | |||||
| }, | |||||
| 'algorithm': { | |||||
| 'network': 'ResNet' | |||||
| }, | |||||
| 'train_dataset': { | |||||
| 'train_dataset_size': 731 | |||||
| }, | |||||
| 'valid_dataset': { | |||||
| 'valid_dataset_size': 10240 | |||||
| }, | |||||
| 'model': { | |||||
| 'path': '{"ckpt": "' | |||||
| + BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}', | |||||
| 'size': 64 | |||||
| }, | |||||
| 'dataset_graph': DATASET_GRAPH | |||||
| } | |||||
| LINEAGE_FILTRATION_EXCEPT_RUN = { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 1024, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': None, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 10, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 64, | |||||
| 'metric': {}, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| } | |||||
| LINEAGE_FILTRATION_RUN1 = { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 731, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': None, | |||||
| 'model_size': 64, | |||||
| 'metric': { | |||||
| 'accuracy': 0.78 | |||||
| }, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| } | |||||
| LINEAGE_FILTRATION_RUN2 = { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||||
| 'loss_function': None, | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': None, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': None, | |||||
| 'optimizer': None, | |||||
| 'learning_rate': None, | |||||
| 'epoch': None, | |||||
| 'batch_size': None, | |||||
| 'loss': None, | |||||
| 'model_size': None, | |||||
| 'metric': { | |||||
| 'accuracy': 2.7800000000000002 | |||||
| }, | |||||
| 'dataset_graph': {}, | |||||
| 'dataset_mark': 3 | |||||
| } | |||||
| @pytest.mark.usefixtures("create_summary_dir") | @pytest.mark.usefixtures("create_summary_dir") | ||||
| @@ -67,36 +157,7 @@ class TestModelApi(TestCase): | |||||
| total_res = get_summary_lineage(SUMMARY_DIR) | total_res = get_summary_lineage(SUMMARY_DIR) | ||||
| partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters']) | partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters']) | ||||
| partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm']) | partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm']) | ||||
| expect_total_res = { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||||
| 'metric': { | |||||
| 'accuracy': 0.78 | |||||
| }, | |||||
| 'hyper_parameters': { | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'epoch': 14, | |||||
| 'parallel_mode': 'stand_alone', | |||||
| 'device_num': 2, | |||||
| 'batch_size': 32 | |||||
| }, | |||||
| 'algorithm': { | |||||
| 'network': 'ResNet' | |||||
| }, | |||||
| 'train_dataset': { | |||||
| 'train_dataset_size': 731 | |||||
| }, | |||||
| 'valid_dataset': { | |||||
| 'valid_dataset_size': 10240 | |||||
| }, | |||||
| 'model': { | |||||
| 'path': '{"ckpt": "' | |||||
| + BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}', | |||||
| 'size': 64 | |||||
| }, | |||||
| 'dataset_graph': DATASET_GRAPH | |||||
| } | |||||
| expect_total_res = LINEAGE_INFO_RUN1 | |||||
| expect_partial_res1 = { | expect_partial_res1 = { | ||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | ||||
| 'hyper_parameters': { | 'hyper_parameters': { | ||||
| @@ -139,7 +200,7 @@ class TestModelApi(TestCase): | |||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| def test_get_summary_lineage_exception(self): | |||||
| def test_get_summary_lineage_exception_1(self): | |||||
| """Test the interface of get_summary_lineage with exception.""" | """Test the interface of get_summary_lineage with exception.""" | ||||
| # summary path does not exist | # summary path does not exist | ||||
| self.assertRaisesRegex( | self.assertRaisesRegex( | ||||
| @@ -183,6 +244,14 @@ class TestModelApi(TestCase): | |||||
| keys=None | keys=None | ||||
| ) | ) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_single | |||||
| def test_get_summary_lineage_exception_2(self): | |||||
| """Test the interface of get_summary_lineage with exception.""" | |||||
| # keys is invalid | # keys is invalid | ||||
| self.assertRaisesRegex( | self.assertRaisesRegex( | ||||
| LineageParamValueError, | LineageParamValueError, | ||||
| @@ -250,64 +319,9 @@ class TestModelApi(TestCase): | |||||
| """Test the interface of filter_summary_lineage.""" | """Test the interface of filter_summary_lineage.""" | ||||
| expect_result = { | expect_result = { | ||||
| 'object': [ | 'object': [ | ||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 1024, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': None, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 10, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 64, | |||||
| 'metric': {}, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| }, | |||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 731, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': None, | |||||
| 'model_size': 64, | |||||
| 'metric': { | |||||
| 'accuracy': 0.78 | |||||
| }, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| }, | |||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||||
| 'loss_function': None, | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': None, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': None, | |||||
| 'optimizer': None, | |||||
| 'learning_rate': None, | |||||
| 'epoch': None, | |||||
| 'batch_size': None, | |||||
| 'loss': None, | |||||
| 'model_size': None, | |||||
| 'metric': { | |||||
| 'accuracy': 2.7800000000000002 | |||||
| }, | |||||
| 'dataset_graph': {}, | |||||
| 'dataset_mark': 3 | |||||
| } | |||||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||||
| LINEAGE_FILTRATION_RUN1, | |||||
| LINEAGE_FILTRATION_RUN2 | |||||
| ], | ], | ||||
| 'count': 3 | 'count': 3 | ||||
| } | } | ||||
| @@ -357,46 +371,8 @@ class TestModelApi(TestCase): | |||||
| } | } | ||||
| expect_result = { | expect_result = { | ||||
| 'object': [ | 'object': [ | ||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||||
| 'loss_function': None, | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': None, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': None, | |||||
| 'optimizer': None, | |||||
| 'learning_rate': None, | |||||
| 'epoch': None, | |||||
| 'batch_size': None, | |||||
| 'loss': None, | |||||
| 'model_size': None, | |||||
| 'metric': { | |||||
| 'accuracy': 2.7800000000000002 | |||||
| }, | |||||
| 'dataset_graph': {}, | |||||
| 'dataset_mark': 3 | |||||
| }, | |||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 731, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': None, | |||||
| 'model_size': 64, | |||||
| 'metric': { | |||||
| 'accuracy': 0.78 | |||||
| }, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| } | |||||
| LINEAGE_FILTRATION_RUN2, | |||||
| LINEAGE_FILTRATION_RUN1 | |||||
| ], | ], | ||||
| 'count': 2 | 'count': 2 | ||||
| } | } | ||||
| @@ -432,46 +408,8 @@ class TestModelApi(TestCase): | |||||
| } | } | ||||
| expect_result = { | expect_result = { | ||||
| 'object': [ | 'object': [ | ||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), | |||||
| 'loss_function': None, | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': None, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': None, | |||||
| 'optimizer': None, | |||||
| 'learning_rate': None, | |||||
| 'epoch': None, | |||||
| 'batch_size': None, | |||||
| 'loss': None, | |||||
| 'model_size': None, | |||||
| 'metric': { | |||||
| 'accuracy': 2.7800000000000002 | |||||
| }, | |||||
| 'dataset_graph': {}, | |||||
| 'dataset_mark': 3 | |||||
| }, | |||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 731, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': None, | |||||
| 'model_size': 64, | |||||
| 'metric': { | |||||
| 'accuracy': 0.78 | |||||
| }, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| } | |||||
| LINEAGE_FILTRATION_RUN2, | |||||
| LINEAGE_FILTRATION_RUN1 | |||||
| ], | ], | ||||
| 'count': 2 | 'count': 2 | ||||
| } | } | ||||
| @@ -498,44 +436,8 @@ class TestModelApi(TestCase): | |||||
| } | } | ||||
| expect_result = { | expect_result = { | ||||
| 'object': [ | 'object': [ | ||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 1024, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': None, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 10, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 64, | |||||
| 'metric': {}, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| }, | |||||
| { | |||||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 731, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 10240, | |||||
| 'network': 'ResNet', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': None, | |||||
| 'model_size': 64, | |||||
| 'metric': { | |||||
| 'accuracy': 0.78 | |||||
| }, | |||||
| 'dataset_graph': DATASET_GRAPH, | |||||
| 'dataset_mark': 2 | |||||
| } | |||||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||||
| LINEAGE_FILTRATION_RUN1 | |||||
| ], | ], | ||||
| 'count': 2 | 'count': 2 | ||||
| } | } | ||||
| @@ -674,6 +576,14 @@ class TestModelApi(TestCase): | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_single | |||||
| def test_filter_summary_lineage_exception_3(self): | |||||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | |||||
| # the condition of offset is invalid | # the condition of offset is invalid | ||||
| search_condition = { | search_condition = { | ||||
| 'offset': 1.0 | 'offset': 1.0 | ||||
| @@ -712,6 +622,14 @@ class TestModelApi(TestCase): | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_single | |||||
| def test_filter_summary_lineage_exception_4(self): | |||||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | |||||
| # the sorted_type not supported | # the sorted_type not supported | ||||
| search_condition = { | search_condition = { | ||||
| 'sorted_name': 'summary_dir', | 'sorted_name': 'summary_dir', | ||||
| @@ -753,6 +671,14 @@ class TestModelApi(TestCase): | |||||
| search_condition | search_condition | ||||
| ) | ) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_single | |||||
| def test_filter_summary_lineage_exception_5(self): | |||||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | |||||
| # the summary dir is invalid in search condition | # the summary dir is invalid in search condition | ||||
| search_condition = { | search_condition = { | ||||
| 'summary_dir': { | 'summary_dir': { | ||||
| @@ -811,7 +737,7 @@ class TestModelApi(TestCase): | |||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| def test_filter_summary_lineage_exception_3(self): | |||||
| def test_filter_summary_lineage_exception_6(self): | |||||
| """Test the abnormal execution of the filter_summary_lineage interface.""" | """Test the abnormal execution of the filter_summary_lineage interface.""" | ||||
| # gt > lt | # gt > lt | ||||
| search_condition1 = { | search_condition1 = { | ||||
| @@ -21,7 +21,8 @@ import tempfile | |||||
| import pytest | import pytest | ||||
| from .collection.model import mindspore | |||||
| from ....utils import mindspore | |||||
| from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE | |||||
| sys.modules['mindspore'] = mindspore | sys.modules['mindspore'] = mindspore | ||||
| @@ -32,52 +33,7 @@ SUMMARY_DIR_3 = os.path.join(BASE_SUMMARY_DIR, 'except_run') | |||||
| COLLECTION_MODULE = 'TestModelLineage' | COLLECTION_MODULE = 'TestModelLineage' | ||||
| API_MODULE = 'TestModelApi' | API_MODULE = 'TestModelApi' | ||||
| DATASET_GRAPH = { | |||||
| 'op_type': 'BatchDataset', | |||||
| 'op_module': 'minddata.dataengine.datasets', | |||||
| 'num_parallel_workers': None, | |||||
| 'drop_remainder': True, | |||||
| 'batch_size': 10, | |||||
| 'children': [ | |||||
| { | |||||
| 'op_type': 'MapDataset', | |||||
| 'op_module': 'minddata.dataengine.datasets', | |||||
| 'num_parallel_workers': None, | |||||
| 'input_columns': [ | |||||
| 'label' | |||||
| ], | |||||
| 'output_columns': [ | |||||
| None | |||||
| ], | |||||
| 'operations': [ | |||||
| { | |||||
| 'tensor_op_module': 'minddata.transforms.c_transforms', | |||||
| 'tensor_op_name': 'OneHot', | |||||
| 'num_classes': 10 | |||||
| } | |||||
| ], | |||||
| 'children': [ | |||||
| { | |||||
| 'op_type': 'MnistDataset', | |||||
| 'shard_id': None, | |||||
| 'num_shards': None, | |||||
| 'op_module': 'minddata.dataengine.datasets', | |||||
| 'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData', | |||||
| 'num_parallel_workers': None, | |||||
| 'shuffle': None, | |||||
| 'num_samples': 100, | |||||
| 'sampler': { | |||||
| 'sampler_module': 'minddata.dataengine.samplers', | |||||
| 'sampler_name': 'RandomSampler', | |||||
| 'replacement': True, | |||||
| 'num_samples': 100 | |||||
| }, | |||||
| 'children': [] | |||||
| } | |||||
| ] | |||||
| } | |||||
| ] | |||||
| } | |||||
| DATASET_GRAPH = SERIALIZED_PIPELINE | |||||
| def get_module_name(nodeid): | def get_module_name(nodeid): | ||||
| """Get the module name from nodeid.""" | """Get the module name from nodeid.""" | ||||
| @@ -14,6 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Import the mocked mindspore.""" | """Import the mocked mindspore.""" | ||||
| import sys | import sys | ||||
| from .lineagemgr.collection.model import mindspore | |||||
| from ..utils import mindspore | |||||
| sys.modules['mindspore'] = mindspore | sys.modules['mindspore'] = mindspore | ||||
| @@ -21,14 +21,15 @@ Usage: | |||||
| from unittest.mock import patch | from unittest.mock import patch | ||||
| import pytest | import pytest | ||||
| from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES | |||||
| from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator | |||||
| from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||||
| from tests.ut.datavisual.utils.utils import get_url | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager | from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager | ||||
| from ....utils.log_generators.images_log_generator import ImagesLogGenerator | |||||
| from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||||
| from ....utils.tools import get_url | |||||
| from .conftest import TRAIN_ROUTES | |||||
| class TestTrainTask: | class TestTrainTask: | ||||
| """Test train task api.""" | """Test train task api.""" | ||||
| @@ -36,9 +37,7 @@ class TestTrainTask: | |||||
| _scalar_log_generator = ScalarsLogGenerator() | _scalar_log_generator = ScalarsLogGenerator() | ||||
| _image_log_generator = ImagesLogGenerator() | _image_log_generator = ImagesLogGenerator() | ||||
| @pytest.mark.parametrize( | |||||
| "plugin_name", | |||||
| ['no_plugin_name', 'not_exist_plugin_name']) | |||||
| @pytest.mark.parametrize("plugin_name", ['no_plugin_name', 'not_exist_plugin_name']) | |||||
| def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name): | def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name): | ||||
| """ | """ | ||||
| Parsing unavailable plugin name to single train task. | Parsing unavailable plugin name to single train task. | ||||
| @@ -21,14 +21,15 @@ Usage: | |||||
| from unittest.mock import Mock, patch | from unittest.mock import Mock, patch | ||||
| import pytest | import pytest | ||||
| from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES | |||||
| from tests.ut.datavisual.utils.utils import get_url | |||||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | ||||
| from mindinsight.datavisual.processors.graph_processor import GraphProcessor | from mindinsight.datavisual.processors.graph_processor import GraphProcessor | ||||
| from mindinsight.datavisual.processors.images_processor import ImageProcessor | from mindinsight.datavisual.processors.images_processor import ImageProcessor | ||||
| from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | ||||
| from ....utils.tools import get_url | |||||
| from .conftest import TRAIN_ROUTES | |||||
| class TestTrainVisual: | class TestTrainVisual: | ||||
| """Test Train Visual APIs.""" | """Test Train Visual APIs.""" | ||||
| @@ -95,14 +96,7 @@ class TestTrainVisual: | |||||
| assert response.status_code == 200 | assert response.status_code == 200 | ||||
| response = response.get_json() | response = response.get_json() | ||||
| expected_response = { | |||||
| "metadatas": [{ | |||||
| "height": 224, | |||||
| "step": 1, | |||||
| "wall_time": 1572058058.1175, | |||||
| "width": 448 | |||||
| }] | |||||
| } | |||||
| expected_response = {"metadatas": [{"height": 224, "step": 1, "wall_time": 1572058058.1175, "width": 448}]} | |||||
| assert expected_response == response | assert expected_response == response | ||||
| def test_single_image_with_params_miss(self, client): | def test_single_image_with_params_miss(self, client): | ||||
| @@ -254,8 +248,10 @@ class TestTrainVisual: | |||||
| @patch.object(GraphProcessor, 'get_nodes') | @patch.object(GraphProcessor, 'get_nodes') | ||||
| def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client): | def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client): | ||||
| """Test getting graph nodes successfully.""" | """Test getting graph nodes successfully.""" | ||||
| def mock_get_nodes(name, node_type): | def mock_get_nodes(name, node_type): | ||||
| return dict(name=name, node_type=node_type) | return dict(name=name, node_type=node_type) | ||||
| mock_graph_processor.side_effect = mock_get_nodes | mock_graph_processor.side_effect = mock_get_nodes | ||||
| mock_init = Mock(return_value=None) | mock_init = Mock(return_value=None) | ||||
| @@ -327,10 +323,7 @@ class TestTrainVisual: | |||||
| assert results['error_msg'] == "Invalid parameter value. 'offset' should " \ | assert results['error_msg'] == "Invalid parameter value. 'offset' should " \ | ||||
| "be greater than or equal to 0." | "be greater than or equal to 0." | ||||
| @pytest.mark.parametrize( | |||||
| "limit", | |||||
| [-1, 0, 1001] | |||||
| ) | |||||
| @pytest.mark.parametrize("limit", [-1, 0, 1001]) | |||||
| @patch.object(GraphProcessor, '__init__') | @patch.object(GraphProcessor, '__init__') | ||||
| def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit): | def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit): | ||||
| """Test getting graph node names with invalid limit.""" | """Test getting graph node names with invalid limit.""" | ||||
| @@ -348,14 +341,10 @@ class TestTrainVisual: | |||||
| assert results['error_msg'] == "Invalid parameter value. " \ | assert results['error_msg'] == "Invalid parameter value. " \ | ||||
| "'limit' should in [1, 1000]." | "'limit' should in [1, 1000]." | ||||
| @pytest.mark.parametrize( | |||||
| " offset, limit", | |||||
| [(0, 100), (1, 1), (0, 1000)] | |||||
| ) | |||||
| @pytest.mark.parametrize(" offset, limit", [(0, 100), (1, 1), (0, 1000)]) | |||||
| @patch.object(GraphProcessor, '__init__') | @patch.object(GraphProcessor, '__init__') | ||||
| @patch.object(GraphProcessor, 'search_node_names') | @patch.object(GraphProcessor, 'search_node_names') | ||||
| def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, | |||||
| offset, limit): | |||||
| def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, offset, limit): | |||||
| """ | """ | ||||
| Parsing unavailable params to get image metadata. | Parsing unavailable params to get image metadata. | ||||
| @@ -367,8 +356,10 @@ class TestTrainVisual: | |||||
| response status code: 200. | response status code: 200. | ||||
| response json: dict, contains search_content, offset, and limit. | response json: dict, contains search_content, offset, and limit. | ||||
| """ | """ | ||||
| def mock_search_node_names(search_content, offset, limit): | def mock_search_node_names(search_content, offset, limit): | ||||
| return dict(search_content=search_content, offset=int(offset), limit=int(limit)) | return dict(search_content=search_content, offset=int(offset), limit=int(limit)) | ||||
| mock_graph_processor.side_effect = mock_search_node_names | mock_graph_processor.side_effect = mock_search_node_names | ||||
| mock_init = Mock(return_value=None) | mock_init = Mock(return_value=None) | ||||
| @@ -376,15 +367,12 @@ class TestTrainVisual: | |||||
| test_train_id = "aaa" | test_train_id = "aaa" | ||||
| test_search_content = "bbb" | test_search_content = "bbb" | ||||
| params = dict(train_id=test_train_id, search=test_search_content, | |||||
| offset=offset, limit=limit) | |||||
| params = dict(train_id=test_train_id, search=test_search_content, offset=offset, limit=limit) | |||||
| url = get_url(TRAIN_ROUTES['graph_nodes_names'], params) | url = get_url(TRAIN_ROUTES['graph_nodes_names'], params) | ||||
| response = client.get(url) | response = client.get(url) | ||||
| assert response.status_code == 200 | assert response.status_code == 200 | ||||
| results = response.get_json() | results = response.get_json() | ||||
| assert results == dict(search_content=test_search_content, | |||||
| offset=int(offset), | |||||
| limit=int(limit)) | |||||
| assert results == dict(search_content=test_search_content, offset=int(offset), limit=int(limit)) | |||||
| def test_graph_search_single_node_with_params_is_wrong(self, client): | def test_graph_search_single_node_with_params_is_wrong(self, client): | ||||
| """Test searching graph single node with params is wrong.""" | """Test searching graph single node with params is wrong.""" | ||||
| @@ -427,8 +415,10 @@ class TestTrainVisual: | |||||
| response status code: 200. | response status code: 200. | ||||
| response json: name. | response json: name. | ||||
| """ | """ | ||||
| def mock_search_single_node(name): | def mock_search_single_node(name): | ||||
| return name | return name | ||||
| mock_graph_processor.side_effect = mock_search_single_node | mock_graph_processor.side_effect = mock_search_single_node | ||||
| mock_init = Mock(return_value=None) | mock_init = Mock(return_value=None) | ||||
| @@ -20,8 +20,42 @@ from unittest import TestCase, mock | |||||
| from flask import Response | from flask import Response | ||||
| from mindinsight.backend.application import APP | from mindinsight.backend.application import APP | ||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||||
| LineageQuerySummaryDataError | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError | |||||
| LINEAGE_FILTRATION_BASE = { | |||||
| 'accuracy': None, | |||||
| 'mae': None, | |||||
| 'mse': None, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 64, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': None, | |||||
| 'network': 'str', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 12, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 128 | |||||
| } | |||||
| LINEAGE_FILTRATION_RUN1 = { | |||||
| 'accuracy': 0.78, | |||||
| 'mae': None, | |||||
| 'mse': None, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 64, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 64, | |||||
| 'network': 'str', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 128 | |||||
| } | |||||
| class TestSearchModel(TestCase): | class TestSearchModel(TestCase): | ||||
| @@ -42,39 +76,11 @@ class TestSearchModel(TestCase): | |||||
| 'object': [ | 'object': [ | ||||
| { | { | ||||
| 'summary_dir': base_dir, | 'summary_dir': base_dir, | ||||
| 'accuracy': None, | |||||
| 'mae': None, | |||||
| 'mse': None, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 64, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': None, | |||||
| 'network': 'str', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 12, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 128 | |||||
| **LINEAGE_FILTRATION_BASE | |||||
| }, | }, | ||||
| { | { | ||||
| 'summary_dir': os.path.join(base_dir, 'run1'), | 'summary_dir': os.path.join(base_dir, 'run1'), | ||||
| 'accuracy': 0.78, | |||||
| 'mae': None, | |||||
| 'mse': None, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 64, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 64, | |||||
| 'network': 'str', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 128 | |||||
| **LINEAGE_FILTRATION_RUN1 | |||||
| } | } | ||||
| ], | ], | ||||
| 'count': 2 | 'count': 2 | ||||
| @@ -93,39 +99,11 @@ class TestSearchModel(TestCase): | |||||
| 'object': [ | 'object': [ | ||||
| { | { | ||||
| 'summary_dir': './', | 'summary_dir': './', | ||||
| 'accuracy': None, | |||||
| 'mae': None, | |||||
| 'mse': None, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 64, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': None, | |||||
| 'network': 'str', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 12, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 128 | |||||
| **LINEAGE_FILTRATION_BASE | |||||
| }, | }, | ||||
| { | { | ||||
| 'summary_dir': './run1', | 'summary_dir': './run1', | ||||
| 'accuracy': 0.78, | |||||
| 'mae': None, | |||||
| 'mse': None, | |||||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||||
| 'train_dataset_path': None, | |||||
| 'train_dataset_count': 64, | |||||
| 'test_dataset_path': None, | |||||
| 'test_dataset_count': 64, | |||||
| 'network': 'str', | |||||
| 'optimizer': 'Momentum', | |||||
| 'learning_rate': 0.11999999731779099, | |||||
| 'epoch': 14, | |||||
| 'batch_size': 32, | |||||
| 'loss': 0.029999999329447746, | |||||
| 'model_size': 128 | |||||
| **LINEAGE_FILTRATION_RUN1 | |||||
| } | } | ||||
| ], | ], | ||||
| 'count': 2 | 'count': 2 | ||||
| @@ -19,15 +19,16 @@ Usage: | |||||
| pytest tests/ut/datavisual | pytest tests/ut/datavisual | ||||
| """ | """ | ||||
| from unittest.mock import patch | from unittest.mock import patch | ||||
| from werkzeug.exceptions import MethodNotAllowed, NotFound | |||||
| from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES | |||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from tests.ut.datavisual.utils.utils import get_url | |||||
| from werkzeug.exceptions import MethodNotAllowed, NotFound | |||||
| from mindinsight.datavisual.processors import scalars_processor | from mindinsight.datavisual.processors import scalars_processor | ||||
| from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | ||||
| from ....utils.tools import get_url | |||||
| from ...backend.datavisual.conftest import TRAIN_ROUTES | |||||
| from ..mock import MockLogger | |||||
| class TestErrorHandler: | class TestErrorHandler: | ||||
| """Test train visual api.""" | """Test train visual api.""" | ||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | """ | ||||
| Function: | Function: | ||||
| Test mindinsight.datavisual.data_transform.log_generators.data_loader_generator | |||||
| Test mindinsight.datavisual.data_transform.loader_generators.data_loader_generator | |||||
| Usage: | Usage: | ||||
| pytest tests/ut/datavisual | pytest tests/ut/datavisual | ||||
| """ | """ | ||||
| @@ -22,18 +22,19 @@ import datetime | |||||
| import os | import os | ||||
| import shutil | import shutil | ||||
| import tempfile | import tempfile | ||||
| from unittest.mock import patch | from unittest.mock import patch | ||||
| import pytest | |||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| import pytest | |||||
| from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator | from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from ...mock import MockLogger | |||||
| class TestDataLoaderGenerator: | class TestDataLoaderGenerator: | ||||
| """Test data_loader_generator.""" | """Test data_loader_generator.""" | ||||
| @classmethod | @classmethod | ||||
| def setup_class(cls): | def setup_class(cls): | ||||
| data_loader_generator.logger = MockLogger | data_loader_generator.logger = MockLogger | ||||
| @@ -88,8 +89,9 @@ class TestDataLoaderGenerator: | |||||
| mock_data_loader.return_value = True | mock_data_loader.return_value = True | ||||
| loader_dict = generator.generate_loaders(loader_pool=dict()) | loader_dict = generator.generate_loaders(loader_pool=dict()) | ||||
| expected_ids = [summary.get('relative_path') | |||||
| for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]] | |||||
| expected_ids = [ | |||||
| summary.get('relative_path') for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:] | |||||
| ] | |||||
| assert sorted(loader_dict.keys()) == sorted(expected_ids) | assert sorted(loader_dict.keys()) == sorted(expected_ids) | ||||
| shutil.rmtree(summary_base_dir) | shutil.rmtree(summary_base_dir) | ||||
| @@ -23,12 +23,13 @@ import shutil | |||||
| import tempfile | import tempfile | ||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid | from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid | ||||
| from mindinsight.datavisual.data_transform import data_loader | from mindinsight.datavisual.data_transform import data_loader | ||||
| from mindinsight.datavisual.data_transform.data_loader import DataLoader | from mindinsight.datavisual.data_transform.data_loader import DataLoader | ||||
| from ..mock import MockLogger | |||||
| class TestDataLoader: | class TestDataLoader: | ||||
| """Test data_loader.""" | """Test data_loader.""" | ||||
| @@ -37,13 +38,13 @@ class TestDataLoader: | |||||
| def setup_class(cls): | def setup_class(cls): | ||||
| data_loader.logger = MockLogger | data_loader.logger = MockLogger | ||||
| def setup_method(self, method): | |||||
| def setup_method(self): | |||||
| self._summary_dir = tempfile.mkdtemp() | self._summary_dir = tempfile.mkdtemp() | ||||
| if os.path.exists(self._summary_dir): | if os.path.exists(self._summary_dir): | ||||
| shutil.rmtree(self._summary_dir) | shutil.rmtree(self._summary_dir) | ||||
| os.mkdir(self._summary_dir) | os.mkdir(self._summary_dir) | ||||
| def teardown_method(self, method): | |||||
| def teardown_method(self): | |||||
| if os.path.exists(self._summary_dir): | if os.path.exists(self._summary_dir): | ||||
| shutil.rmtree(self._summary_dir) | shutil.rmtree(self._summary_dir) | ||||
| @@ -18,32 +18,29 @@ Function: | |||||
| Usage: | Usage: | ||||
| pytest tests/ut/datavisual | pytest tests/ut/datavisual | ||||
| """ | """ | ||||
| import time | |||||
| import os | import os | ||||
| import shutil | import shutil | ||||
| import tempfile | import tempfile | ||||
| import time | |||||
| from unittest import mock | from unittest import mock | ||||
| from unittest.mock import Mock | |||||
| from unittest.mock import patch | |||||
| from unittest.mock import Mock, patch | |||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from tests.ut.datavisual.utils.utils import check_loading_done | |||||
| from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum | from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum | ||||
| from mindinsight.datavisual.data_transform import data_manager, ms_data_loader | from mindinsight.datavisual.data_transform import data_manager, ms_data_loader | ||||
| from mindinsight.datavisual.data_transform.data_loader import DataLoader | from mindinsight.datavisual.data_transform.data_loader import DataLoader | ||||
| from mindinsight.datavisual.data_transform.data_manager import DataManager | from mindinsight.datavisual.data_transform.data_manager import DataManager | ||||
| from mindinsight.datavisual.data_transform.events_data import EventsData | from mindinsight.datavisual.data_transform.events_data import EventsData | ||||
| from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import \ | |||||
| DataLoaderGenerator | |||||
| from mindinsight.datavisual.data_transform.loader_generators.loader_generator import \ | |||||
| MAX_DATA_LOADER_SIZE | |||||
| from mindinsight.datavisual.data_transform.loader_generators.loader_struct import \ | |||||
| LoaderStruct | |||||
| from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator | |||||
| from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE | |||||
| from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct | |||||
| from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from ....utils.tools import check_loading_done | |||||
| from ..mock import MockLogger | |||||
| class TestDataManager: | class TestDataManager: | ||||
| """Test data_manager.""" | """Test data_manager.""" | ||||
| @@ -101,11 +98,17 @@ class TestDataManager: | |||||
| "and loader pool size is '3'." | "and loader pool size is '3'." | ||||
| shutil.rmtree(summary_base_dir) | shutil.rmtree(summary_base_dir) | ||||
| @pytest.mark.parametrize('params', | |||||
| [{'reload_interval': '30'}, | |||||
| {'reload_interval': -1}, | |||||
| {'reload_interval': 30, 'max_threads_count': '20'}, | |||||
| {'reload_interval': 30, 'max_threads_count': 0}]) | |||||
| @pytest.mark.parametrize('params', [{ | |||||
| 'reload_interval': '30' | |||||
| }, { | |||||
| 'reload_interval': -1 | |||||
| }, { | |||||
| 'reload_interval': 30, | |||||
| 'max_threads_count': '20' | |||||
| }, { | |||||
| 'reload_interval': 30, | |||||
| 'max_threads_count': 0 | |||||
| }]) | |||||
| def test_start_load_data_with_invalid_params(self, params): | def test_start_load_data_with_invalid_params(self, params): | ||||
| """Test start_load_data with invalid reload_interval or invalid max_threads_count.""" | """Test start_load_data with invalid reload_interval or invalid max_threads_count.""" | ||||
| summary_base_dir = tempfile.mkdtemp() | summary_base_dir = tempfile.mkdtemp() | ||||
| @@ -22,20 +22,24 @@ import threading | |||||
| from collections import namedtuple | from collections import namedtuple | ||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| from mindinsight.datavisual.data_transform import events_data | from mindinsight.datavisual.data_transform import events_data | ||||
| from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor | from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor | ||||
| from ..mock import MockLogger | |||||
| class MockReservoir: | class MockReservoir: | ||||
| """Use this class to replace reservoir.Reservoir in test.""" | """Use this class to replace reservoir.Reservoir in test.""" | ||||
| def __init__(self, size): | def __init__(self, size): | ||||
| self.size = size | self.size = size | ||||
| self._samples = [_Tensor('wall_time1', 1, 'value1'), _Tensor('wall_time2', 2, 'value2'), | |||||
| _Tensor('wall_time3', 3, 'value3')] | |||||
| self._samples = [ | |||||
| _Tensor('wall_time1', 1, 'value1'), | |||||
| _Tensor('wall_time2', 2, 'value2'), | |||||
| _Tensor('wall_time3', 3, 'value3') | |||||
| ] | |||||
| def samples(self): | def samples(self): | ||||
| """Replace the samples function.""" | """Replace the samples function.""" | ||||
| @@ -63,11 +67,12 @@ class TestEventsData: | |||||
| def setup_method(self): | def setup_method(self): | ||||
| """Mock original logger, init a EventsData object for use.""" | """Mock original logger, init a EventsData object for use.""" | ||||
| self._ev_data = EventsData() | self._ev_data = EventsData() | ||||
| self._ev_data._tags_by_plugin = {'plugin_name1': [f'tag{i}' for i in range(10)], | |||||
| 'plugin_name2': [f'tag{i}' for i in range(20, 30)]} | |||||
| self._ev_data._tags_by_plugin = { | |||||
| 'plugin_name1': [f'tag{i}' for i in range(10)], | |||||
| 'plugin_name2': [f'tag{i}' for i in range(20, 30)] | |||||
| } | |||||
| self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()}) | self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()}) | ||||
| self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), | |||||
| 'new_tag': MockReservoir(500)} | |||||
| self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), 'new_tag': MockReservoir(500)} | |||||
| self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)] | self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)] | ||||
| def get_ev_data(self): | def get_ev_data(self): | ||||
| @@ -102,8 +107,7 @@ class TestEventsData: | |||||
| """Test add_tensor_event success.""" | """Test add_tensor_event success.""" | ||||
| ev_data = self.get_ev_data() | ev_data = self.get_ev_data() | ||||
| t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', | |||||
| value='value1') | |||||
| t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', value='value1') | |||||
| ev_data.add_tensor_event(t_event) | ev_data.add_tensor_event(t_event) | ||||
| assert 'tag0' not in ev_data._tags | assert 'tag0' not in ev_data._tags | ||||
| @@ -111,6 +115,5 @@ class TestEventsData: | |||||
| assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1'] | assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1'] | ||||
| assert 'tag0' not in ev_data._reservoir_by_tag | assert 'tag0' not in ev_data._reservoir_by_tag | ||||
| assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1'] | assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1'] | ||||
| assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, | |||||
| t_event.step, | |||||
| assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, t_event.step, | |||||
| t_event.value) | t_event.value) | ||||
| @@ -19,16 +19,17 @@ Usage: | |||||
| pytest tests/ut/datavisual | pytest tests/ut/datavisual | ||||
| """ | """ | ||||
| import os | import os | ||||
| import tempfile | |||||
| import shutil | import shutil | ||||
| import tempfile | |||||
| from unittest.mock import Mock | from unittest.mock import Mock | ||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from mindinsight.datavisual.data_transform import ms_data_loader | from mindinsight.datavisual.data_transform import ms_data_loader | ||||
| from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader | ||||
| from ..mock import MockLogger | |||||
| # bytes of 3 scalar events | # bytes of 3 scalar events | ||||
| SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' | SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' | ||||
| b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00' | b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00' | ||||
| @@ -74,7 +75,8 @@ class TestMsDataLoader: | |||||
| "we will reload all files in path {}.".format(summary_dir) | "we will reload all files in path {}.".format(summary_dir) | ||||
| shutil.rmtree(summary_dir) | shutil.rmtree(summary_dir) | ||||
| def test_load_success_with_crc_pass(self, crc_pass): | |||||
| @pytest.mark.usefixtures('crc_pass') | |||||
| def test_load_success_with_crc_pass(self): | |||||
| """Test load success.""" | """Test load success.""" | ||||
| summary_dir = tempfile.mkdtemp() | summary_dir = tempfile.mkdtemp() | ||||
| file1 = os.path.join(summary_dir, 'summary.01') | file1 = os.path.join(summary_dir, 'summary.01') | ||||
| @@ -88,7 +90,8 @@ class TestMsDataLoader: | |||||
| tensors = ms_loader.get_events_data().tensors(tag[0]) | tensors = ms_loader.get_events_data().tensors(tag[0]) | ||||
| assert len(tensors) == 3 | assert len(tensors) == 3 | ||||
| def test_load_with_crc_fail(self, crc_fail): | |||||
| @pytest.mark.usefixtures('crc_fail') | |||||
| def test_load_with_crc_fail(self): | |||||
| """Test when crc_fail and will not go to func _event_parse.""" | """Test when crc_fail and will not go to func _event_parse.""" | ||||
| summary_dir = tempfile.mkdtemp() | summary_dir = tempfile.mkdtemp() | ||||
| file2 = os.path.join(summary_dir, 'summary.02') | file2 = os.path.join(summary_dir, 'summary.02') | ||||
| @@ -100,8 +103,10 @@ class TestMsDataLoader: | |||||
| def test_filter_event_files(self): | def test_filter_event_files(self): | ||||
| """Test filter_event_files function ok.""" | """Test filter_event_files function ok.""" | ||||
| file_list = ['abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', | |||||
| 'summary.0012', 'hellosummary.98786', 'mysummary.123abce', 'summay.4567'] | |||||
| file_list = [ | |||||
| 'abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', 'summary.0012', 'hellosummary.98786', | |||||
| 'mysummary.123abce', 'summay.4567' | |||||
| ] | |||||
| summary_dir = tempfile.mkdtemp() | summary_dir = tempfile.mkdtemp() | ||||
| for file in file_list: | for file in file_list: | ||||
| with open(os.path.join(summary_dir, file), 'w'): | with open(os.path.join(summary_dir, file), 'w'): | ||||
| @@ -113,6 +118,7 @@ class TestMsDataLoader: | |||||
| shutil.rmtree(summary_dir) | shutil.rmtree(summary_dir) | ||||
| def write_file(filename, record): | def write_file(filename, record): | ||||
| """Write bytes strings to file.""" | """Write bytes strings to file.""" | ||||
| with open(filename, 'wb') as file: | with open(filename, 'wb') as file: | ||||
| @@ -19,18 +19,11 @@ Usage: | |||||
| pytest tests/ut/datavisual | pytest tests/ut/datavisual | ||||
| """ | """ | ||||
| import os | import os | ||||
| import json | |||||
| import tempfile | import tempfile | ||||
| from unittest.mock import Mock | |||||
| from unittest.mock import patch | |||||
| from unittest.mock import Mock, patch | |||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||||
| from mindinsight.datavisual.common import exceptions | from mindinsight.datavisual.common import exceptions | ||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from mindinsight.datavisual.data_transform import data_manager | from mindinsight.datavisual.data_transform import data_manager | ||||
| @@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor | |||||
| from mindinsight.datavisual.utils import crc32 | from mindinsight.datavisual.utils import crc32 | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from ....utils.log_operations import LogOperations | |||||
| from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs | |||||
| from ..mock import MockLogger | |||||
| class TestGraphProcessor: | class TestGraphProcessor: | ||||
| """Test Graph Processor api.""" | """Test Graph Processor api.""" | ||||
| @@ -70,18 +67,13 @@ class TestGraphProcessor: | |||||
| """Load graph record.""" | """Load graph record.""" | ||||
| summary_base_dir = tempfile.mkdtemp() | summary_base_dir = tempfile.mkdtemp() | ||||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | log_dir = tempfile.mkdtemp(dir=summary_base_dir) | ||||
| self._train_id = log_dir.replace(summary_base_dir, ".") | self._train_id = log_dir.replace(summary_base_dir, ".") | ||||
| graph_base_path = os.path.join(os.path.dirname(__file__), | |||||
| os.pardir, "utils", "log_generators", "graph_base.json") | |||||
| self._temp_path, self._graph_dict = LogOperations.generate_log( | |||||
| PluginNameEnum.GRAPH.value, log_dir, dict(graph_base_path=graph_base_path)) | |||||
| log_operation = LogOperations() | |||||
| self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir) | |||||
| self._generated_path.append(summary_base_dir) | self._generated_path.append(summary_base_dir) | ||||
| self._mock_data_manager = data_manager.DataManager( | |||||
| [DataLoaderGenerator(summary_base_dir)]) | |||||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | |||||
| self._mock_data_manager.start_load_data(reload_interval=0) | self._mock_data_manager.start_load_data(reload_interval=0) | ||||
| # wait for loading done | # wait for loading done | ||||
| @@ -94,33 +86,29 @@ class TestGraphProcessor: | |||||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | log_dir = tempfile.mkdtemp(dir=summary_base_dir) | ||||
| self._train_id = log_dir.replace(summary_base_dir, ".") | self._train_id = log_dir.replace(summary_base_dir, ".") | ||||
| self._temp_path, _, _ = LogOperations.generate_log( | |||||
| PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image")) | |||||
| log_operation = LogOperations() | |||||
| self._temp_path, _, _ = log_operation.generate_log(PluginNameEnum.IMAGE.value, log_dir, | |||||
| dict(steps=self._steps_list, tag="image")) | |||||
| self._generated_path.append(summary_base_dir) | self._generated_path.append(summary_base_dir) | ||||
| self._mock_data_manager = data_manager.DataManager( | |||||
| [DataLoaderGenerator(summary_base_dir)]) | |||||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | |||||
| self._mock_data_manager.start_load_data(reload_interval=0) | self._mock_data_manager.start_load_data(reload_interval=0) | ||||
| # wait for loading done | # wait for loading done | ||||
| check_loading_done(self._mock_data_manager, time_limit=5) | check_loading_done(self._mock_data_manager, time_limit=5) | ||||
| def compare_result_with_file(self, result, filename): | |||||
| """Compare result with file which contain the expected results.""" | |||||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||||
| expected_results = json.load(fp) | |||||
| assert result == expected_results | |||||
| def test_get_nodes_with_not_exist_train_id(self, load_graph_record): | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| def test_get_nodes_with_not_exist_train_id(self): | |||||
| """Test getting nodes with not exist train id.""" | """Test getting nodes with not exist train id.""" | ||||
| test_train_id = "not_exist_train_id" | test_train_id = "not_exist_train_id" | ||||
| with pytest.raises(ParamValueError) as exc_info: | with pytest.raises(ParamValueError) as exc_info: | ||||
| GraphProcessor(test_train_id, self._mock_data_manager) | GraphProcessor(test_train_id, self._mock_data_manager) | ||||
| assert "Can not find the train job in data manager." in exc_info.value.message | assert "Can not find the train job in data manager." in exc_info.value.message | ||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| @patch.object(DataManager, 'get_train_job_by_plugin') | @patch.object(DataManager, 'get_train_job_by_plugin') | ||||
| def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin, load_graph_record): | |||||
| def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin): | |||||
| """Test get nodes with loader is None.""" | """Test get nodes with loader is None.""" | ||||
| mock_get_train_job_by_plugin.return_value = None | mock_get_train_job_by_plugin.return_value = None | ||||
| with pytest.raises(exceptions.SummaryLogPathInvalid): | with pytest.raises(exceptions.SummaryLogPathInvalid): | ||||
| @@ -128,15 +116,12 @@ class TestGraphProcessor: | |||||
| assert mock_get_train_job_by_plugin.called | assert mock_get_train_job_by_plugin.called | ||||
| @pytest.mark.parametrize("name, node_type", [ | |||||
| ("not_exist_name", "name_scope"), | |||||
| ("", "polymeric_scope") | |||||
| ]) | |||||
| def test_get_nodes_with_not_exist_name(self, load_graph_record, name, node_type): | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| @pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")]) | |||||
| def test_get_nodes_with_not_exist_name(self, name, node_type): | |||||
| """Test getting nodes with not exist name.""" | """Test getting nodes with not exist name.""" | ||||
| with pytest.raises(ParamValueError) as exc_info: | with pytest.raises(ParamValueError) as exc_info: | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| graph_processor.get_nodes(name, node_type) | graph_processor.get_nodes(name, node_type) | ||||
| if name: | if name: | ||||
| @@ -144,105 +129,99 @@ class TestGraphProcessor: | |||||
| else: | else: | ||||
| assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message | assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message | ||||
| @pytest.mark.parametrize("name, node_type, result_file", [ | |||||
| (None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), | |||||
| ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), | |||||
| ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json') | |||||
| ]) | |||||
| def test_get_nodes_success(self, load_graph_record, name, node_type, result_file): | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| @pytest.mark.parametrize( | |||||
| "name, node_type, result_file", | |||||
| [(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), | |||||
| ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), | |||||
| ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')]) | |||||
| def test_get_nodes_success(self, name, node_type, result_file): | |||||
| """Test getting nodes successfully.""" | """Test getting nodes successfully.""" | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| results = graph_processor.get_nodes(name, node_type) | results = graph_processor.get_nodes(name, node_type) | ||||
| self.compare_result_with_file(results, result_file) | |||||
| @pytest.mark.parametrize("search_content, result_file", [ | |||||
| (None, 'test_search_node_names_with_search_content_expected_results1.json'), | |||||
| ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), | |||||
| ('not_exist_search_content', None) | |||||
| ]) | |||||
| def test_search_node_names_with_search_content(self, load_graph_record, | |||||
| search_content, | |||||
| result_file): | |||||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||||
| compare_result_with_file(results, expected_file_path) | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| @pytest.mark.parametrize("search_content, result_file", | |||||
| [(None, 'test_search_node_names_with_search_content_expected_results1.json'), | |||||
| ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), | |||||
| ('not_exist_search_content', None)]) | |||||
| def test_search_node_names_with_search_content(self, search_content, result_file): | |||||
| """Test search node names with search content.""" | """Test search node names with search content.""" | ||||
| test_offset = 0 | test_offset = 0 | ||||
| test_limit = 1000 | test_limit = 1000 | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| results = graph_processor.search_node_names(search_content, | |||||
| test_offset, | |||||
| test_limit) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| results = graph_processor.search_node_names(search_content, test_offset, test_limit) | |||||
| if search_content == 'not_exist_search_content': | if search_content == 'not_exist_search_content': | ||||
| expected_results = {'names': []} | expected_results = {'names': []} | ||||
| assert results == expected_results | assert results == expected_results | ||||
| else: | else: | ||||
| self.compare_result_with_file(results, result_file) | |||||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||||
| compare_result_with_file(results, expected_file_path) | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| @pytest.mark.parametrize("offset", [-100, -1]) | @pytest.mark.parametrize("offset", [-100, -1]) | ||||
| def test_search_node_names_with_negative_offset(self, load_graph_record, offset): | |||||
| def test_search_node_names_with_negative_offset(self, offset): | |||||
| """Test search node names with negative offset.""" | """Test search node names with negative offset.""" | ||||
| test_search_content = "" | test_search_content = "" | ||||
| test_limit = 3 | test_limit = 3 | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| with pytest.raises(ParamValueError) as exc_info: | with pytest.raises(ParamValueError) as exc_info: | ||||
| graph_processor.search_node_names(test_search_content, offset, test_limit) | graph_processor.search_node_names(test_search_content, offset, test_limit) | ||||
| assert "'offset' should be greater than or equal to 0." in exc_info.value.message | assert "'offset' should be greater than or equal to 0." in exc_info.value.message | ||||
| @pytest.mark.parametrize("offset, result_file", [ | |||||
| (1, 'test_search_node_names_with_offset_expected_results1.json') | |||||
| ]) | |||||
| def test_search_node_names_with_offset(self, load_graph_record, offset, result_file): | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| @pytest.mark.parametrize("offset, result_file", [(1, 'test_search_node_names_with_offset_expected_results1.json')]) | |||||
| def test_search_node_names_with_offset(self, offset, result_file): | |||||
| """Test search node names with offset.""" | """Test search node names with offset.""" | ||||
| test_search_content = "Default/bn1" | test_search_content = "Default/bn1" | ||||
| test_offset = offset | test_offset = offset | ||||
| test_limit = 3 | test_limit = 3 | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| results = graph_processor.search_node_names(test_search_content, | |||||
| test_offset, | |||||
| test_limit) | |||||
| self.compare_result_with_file(results, result_file) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| results = graph_processor.search_node_names(test_search_content, test_offset, test_limit) | |||||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||||
| compare_result_with_file(results, expected_file_path) | |||||
| def test_search_node_names_with_wrong_limit(self, load_graph_record): | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| def test_search_node_names_with_wrong_limit(self): | |||||
| """Test search node names with wrong limit.""" | """Test search node names with wrong limit.""" | ||||
| test_search_content = "" | test_search_content = "" | ||||
| test_offset = 0 | test_offset = 0 | ||||
| test_limit = 0 | test_limit = 0 | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| with pytest.raises(ParamValueError) as exc_info: | with pytest.raises(ParamValueError) as exc_info: | ||||
| graph_processor.search_node_names(test_search_content, test_offset, | |||||
| test_limit) | |||||
| graph_processor.search_node_names(test_search_content, test_offset, test_limit) | |||||
| assert "'limit' should in [1, 1000]." in exc_info.value.message | assert "'limit' should in [1, 1000]." in exc_info.value.message | ||||
| @pytest.mark.parametrize("name, result_file", [ | |||||
| ('Default/bn1', 'test_search_single_node_success_expected_results1.json') | |||||
| ]) | |||||
| def test_search_single_node_success(self, load_graph_record, name, result_file): | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| @pytest.mark.parametrize("name, result_file", | |||||
| [('Default/bn1', 'test_search_single_node_success_expected_results1.json')]) | |||||
| def test_search_single_node_success(self, name, result_file): | |||||
| """Test searching single node successfully.""" | """Test searching single node successfully.""" | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| results = graph_processor.search_single_node(name) | results = graph_processor.search_single_node(name) | ||||
| self.compare_result_with_file(results, result_file) | |||||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||||
| compare_result_with_file(results, expected_file_path) | |||||
| def test_search_single_node_with_not_exist_name(self, load_graph_record): | |||||
| @pytest.mark.usefixtures('load_graph_record') | |||||
| def test_search_single_node_with_not_exist_name(self): | |||||
| """Test searching single node with not exist name.""" | """Test searching single node with not exist name.""" | ||||
| test_name = "not_exist_name" | test_name = "not_exist_name" | ||||
| with pytest.raises(exceptions.NodeNotInGraphError): | with pytest.raises(exceptions.NodeNotInGraphError): | ||||
| graph_processor = GraphProcessor(self._train_id, | |||||
| self._mock_data_manager) | |||||
| graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) | |||||
| graph_processor.search_single_node(test_name) | graph_processor.search_single_node(test_name) | ||||
| def test_check_graph_status_no_graph(self, load_no_graph_record): | |||||
| @pytest.mark.usefixtures('load_no_graph_record') | |||||
| def test_check_graph_status_no_graph(self): | |||||
| """Test checking graph status no graph.""" | """Test checking graph status no graph.""" | ||||
| with pytest.raises(ParamValueError) as exc_info: | with pytest.raises(ParamValueError) as exc_info: | ||||
| GraphProcessor(self._train_id, self._mock_data_manager) | GraphProcessor(self._train_id, self._mock_data_manager) | ||||
| @@ -22,9 +22,6 @@ import tempfile | |||||
| from unittest.mock import Mock | from unittest.mock import Mock | ||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from mindinsight.datavisual.data_transform import data_manager | from mindinsight.datavisual.data_transform import data_manager | ||||
| @@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor | |||||
| from mindinsight.datavisual.utils import crc32 | from mindinsight.datavisual.utils import crc32 | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from ....utils.log_operations import LogOperations | |||||
| from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes | |||||
| from ..mock import MockLogger | |||||
| class TestImagesProcessor: | class TestImagesProcessor: | ||||
| """Test images processor api.""" | """Test images processor api.""" | ||||
| @@ -73,12 +74,11 @@ class TestImagesProcessor: | |||||
| """ | """ | ||||
| summary_base_dir = tempfile.mkdtemp() | summary_base_dir = tempfile.mkdtemp() | ||||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | log_dir = tempfile.mkdtemp(dir=summary_base_dir) | ||||
| self._train_id = log_dir.replace(summary_base_dir, ".") | self._train_id = log_dir.replace(summary_base_dir, ".") | ||||
| self._temp_path, self._images_metadata, self._images_values = LogOperations.generate_log( | |||||
| log_operation = LogOperations() | |||||
| self._temp_path, self._images_metadata, self._images_values = log_operation.generate_log( | |||||
| PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name)) | PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name)) | ||||
| self._generated_path.append(summary_base_dir) | self._generated_path.append(summary_base_dir) | ||||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | ||||
| @@ -102,7 +102,8 @@ class TestImagesProcessor: | |||||
| """Load image record.""" | """Load image record.""" | ||||
| self._init_data_manager(self._cross_steps_list) | self._init_data_manager(self._cross_steps_list) | ||||
| def test_get_metadata_list_with_not_exist_id(self, load_image_record): | |||||
| @pytest.mark.usefixtures('load_image_record') | |||||
| def test_get_metadata_list_with_not_exist_id(self): | |||||
| """Test getting metadata list with not exist id.""" | """Test getting metadata list with not exist id.""" | ||||
| test_train_id = 'not_exist_id' | test_train_id = 'not_exist_id' | ||||
| image_processor = ImageProcessor(self._mock_data_manager) | image_processor = ImageProcessor(self._mock_data_manager) | ||||
| @@ -112,7 +113,8 @@ class TestImagesProcessor: | |||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| assert "Can not find any data in loader pool about the train job." in exc_info.value.message | assert "Can not find any data in loader pool about the train job." in exc_info.value.message | ||||
| def test_get_metadata_list_with_not_exist_tag(self, load_image_record): | |||||
| @pytest.mark.usefixtures('load_image_record') | |||||
| def test_get_metadata_list_with_not_exist_tag(self): | |||||
| """Test get metadata list with not exist tag.""" | """Test get metadata list with not exist tag.""" | ||||
| test_tag_name = 'not_exist_tag_name' | test_tag_name = 'not_exist_tag_name' | ||||
| @@ -124,7 +126,8 @@ class TestImagesProcessor: | |||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| assert "Can not find any data in this train job by given tag." in exc_info.value.message | assert "Can not find any data in this train job by given tag." in exc_info.value.message | ||||
| def test_get_metadata_list_success(self, load_image_record): | |||||
| @pytest.mark.usefixtures('load_image_record') | |||||
| def test_get_metadata_list_success(self): | |||||
| """Test getting metadata list success.""" | """Test getting metadata list success.""" | ||||
| test_tag_name = self._complete_tag_name | test_tag_name = self._complete_tag_name | ||||
| @@ -133,7 +136,8 @@ class TestImagesProcessor: | |||||
| assert results == self._images_metadata | assert results == self._images_metadata | ||||
| def test_get_single_image_with_not_exist_id(self, load_image_record): | |||||
| @pytest.mark.usefixtures('load_image_record') | |||||
| def test_get_single_image_with_not_exist_id(self): | |||||
| """Test getting single image with not exist id.""" | """Test getting single image with not exist id.""" | ||||
| test_train_id = 'not_exist_id' | test_train_id = 'not_exist_id' | ||||
| test_tag_name = self._complete_tag_name | test_tag_name = self._complete_tag_name | ||||
| @@ -146,7 +150,8 @@ class TestImagesProcessor: | |||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| assert "Can not find any data in loader pool about the train job." in exc_info.value.message | assert "Can not find any data in loader pool about the train job." in exc_info.value.message | ||||
| def test_get_single_image_with_not_exist_tag(self, load_image_record): | |||||
| @pytest.mark.usefixtures('load_image_record') | |||||
| def test_get_single_image_with_not_exist_tag(self): | |||||
| """Test getting single image with not exist tag.""" | """Test getting single image with not exist tag.""" | ||||
| test_tag_name = 'not_exist_tag_name' | test_tag_name = 'not_exist_tag_name' | ||||
| test_step = self._steps_list[0] | test_step = self._steps_list[0] | ||||
| @@ -159,7 +164,8 @@ class TestImagesProcessor: | |||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| assert "Can not find any data in this train job by given tag." in exc_info.value.message | assert "Can not find any data in this train job by given tag." in exc_info.value.message | ||||
| def test_get_single_image_with_not_exist_step(self, load_image_record): | |||||
| @pytest.mark.usefixtures('load_image_record') | |||||
| def test_get_single_image_with_not_exist_step(self): | |||||
| """Test getting single image with not exist step.""" | """Test getting single image with not exist step.""" | ||||
| test_tag_name = self._complete_tag_name | test_tag_name = self._complete_tag_name | ||||
| test_step = 10000 | test_step = 10000 | ||||
| @@ -172,24 +178,22 @@ class TestImagesProcessor: | |||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| assert "Can not find the step with given train job id and tag." in exc_info.value.message | assert "Can not find the step with given train job id and tag." in exc_info.value.message | ||||
| def test_get_single_image_success(self, load_image_record): | |||||
| @pytest.mark.usefixtures('load_image_record') | |||||
| def test_get_single_image_success(self): | |||||
| """Test getting single image successfully.""" | """Test getting single image successfully.""" | ||||
| test_tag_name = self._complete_tag_name | test_tag_name = self._complete_tag_name | ||||
| test_step_index = 0 | test_step_index = 0 | ||||
| test_step = self._steps_list[test_step_index] | test_step = self._steps_list[test_step_index] | ||||
| expected_image_tensor = self._images_values.get(test_step) | |||||
| image_processor = ImageProcessor(self._mock_data_manager) | image_processor = ImageProcessor(self._mock_data_manager) | ||||
| results = image_processor.get_single_image(self._train_id, test_tag_name, test_step) | results = image_processor.get_single_image(self._train_id, test_tag_name, test_step) | ||||
| expected_image_tensor = self._images_values.get(test_step) | |||||
| image_generator = LogOperations.get_log_generator(PluginNameEnum.IMAGE.value) | |||||
| recv_image_tensor = image_generator.get_image_tensor_from_bytes(results) | |||||
| recv_image_tensor = get_image_tensor_from_bytes(results) | |||||
| assert recv_image_tensor.any() == expected_image_tensor.any() | assert recv_image_tensor.any() == expected_image_tensor.any() | ||||
| def test_reservoir_add_sample(self, load_more_than_limit_image_record): | |||||
| @pytest.mark.usefixtures('load_more_than_limit_image_record') | |||||
| def test_reservoir_add_sample(self): | |||||
| """Test adding sample in reservoir.""" | """Test adding sample in reservoir.""" | ||||
| test_tag_name = self._complete_tag_name | test_tag_name = self._complete_tag_name | ||||
| @@ -206,7 +210,8 @@ class TestImagesProcessor: | |||||
| cnt += 1 | cnt += 1 | ||||
| assert len(self._more_steps_list) - cnt == 10 | assert len(self._more_steps_list) - cnt == 10 | ||||
| def test_reservoir_remove_sample(self, load_reservoir_remove_sample_image_record): | |||||
| @pytest.mark.usefixtures('load_reservoir_remove_sample_image_record') | |||||
| def test_reservoir_remove_sample(self): | |||||
| """ | """ | ||||
| Test removing sample in reservoir. | Test removing sample in reservoir. | ||||
| @@ -22,9 +22,6 @@ import tempfile | |||||
| from unittest.mock import Mock | from unittest.mock import Mock | ||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from mindinsight.datavisual.data_transform import data_manager | from mindinsight.datavisual.data_transform import data_manager | ||||
| @@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | |||||
| from mindinsight.datavisual.utils import crc32 | from mindinsight.datavisual.utils import crc32 | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from ....utils.log_operations import LogOperations | |||||
| from ....utils.tools import check_loading_done, delete_files_or_dirs | |||||
| from ..mock import MockLogger | |||||
| class TestScalarsProcessor: | class TestScalarsProcessor: | ||||
| """Test scalar processor api.""" | """Test scalar processor api.""" | ||||
| @@ -65,12 +66,11 @@ class TestScalarsProcessor: | |||||
| """Load scalar record.""" | """Load scalar record.""" | ||||
| summary_base_dir = tempfile.mkdtemp() | summary_base_dir = tempfile.mkdtemp() | ||||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | log_dir = tempfile.mkdtemp(dir=summary_base_dir) | ||||
| self._train_id = log_dir.replace(summary_base_dir, ".") | self._train_id = log_dir.replace(summary_base_dir, ".") | ||||
| self._temp_path, self._scalars_metadata, self._scalars_values = LogOperations.generate_log( | |||||
| log_operation = LogOperations() | |||||
| self._temp_path, self._scalars_metadata, self._scalars_values = log_operation.generate_log( | |||||
| PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name)) | PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name)) | ||||
| self._generated_path.append(summary_base_dir) | self._generated_path.append(summary_base_dir) | ||||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | ||||
| @@ -79,7 +79,8 @@ class TestScalarsProcessor: | |||||
| # wait for loading done | # wait for loading done | ||||
| check_loading_done(self._mock_data_manager, time_limit=5) | check_loading_done(self._mock_data_manager, time_limit=5) | ||||
| def test_get_metadata_list_with_not_exist_id(self, load_scalar_record): | |||||
| @pytest.mark.usefixtures('load_scalar_record') | |||||
| def test_get_metadata_list_with_not_exist_id(self): | |||||
| """Get metadata list with not exist id.""" | """Get metadata list with not exist id.""" | ||||
| test_train_id = 'not_exist_id' | test_train_id = 'not_exist_id' | ||||
| scalar_processor = ScalarsProcessor(self._mock_data_manager) | scalar_processor = ScalarsProcessor(self._mock_data_manager) | ||||
| @@ -89,7 +90,8 @@ class TestScalarsProcessor: | |||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| assert "Can not find any data in loader pool about the train job." in exc_info.value.message | assert "Can not find any data in loader pool about the train job." in exc_info.value.message | ||||
| def test_get_metadata_list_with_not_exist_tag(self, load_scalar_record): | |||||
| @pytest.mark.usefixtures('load_scalar_record') | |||||
| def test_get_metadata_list_with_not_exist_tag(self): | |||||
| """Get metadata list with not exist tag.""" | """Get metadata list with not exist tag.""" | ||||
| test_tag_name = 'not_exist_tag_name' | test_tag_name = 'not_exist_tag_name' | ||||
| @@ -101,7 +103,8 @@ class TestScalarsProcessor: | |||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| assert "Can not find any data in this train job by given tag." in exc_info.value.message | assert "Can not find any data in this train job by given tag." in exc_info.value.message | ||||
| def test_get_metadata_list_success(self, load_scalar_record): | |||||
| @pytest.mark.usefixtures('load_scalar_record') | |||||
| def test_get_metadata_list_success(self): | |||||
| """Get metadata list success.""" | """Get metadata list success.""" | ||||
| test_tag_name = self._complete_tag_name | test_tag_name = self._complete_tag_name | ||||
| @@ -18,15 +18,11 @@ Function: | |||||
| Usage: | Usage: | ||||
| pytest tests/ut/datavisual | pytest tests/ut/datavisual | ||||
| """ | """ | ||||
| import os | |||||
| import tempfile | import tempfile | ||||
| import time | import time | ||||
| from unittest.mock import Mock | from unittest.mock import Mock | ||||
| import pytest | import pytest | ||||
| from tests.ut.datavisual.mock import MockLogger | |||||
| from tests.ut.datavisual.utils.log_operations import LogOperations | |||||
| from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from mindinsight.datavisual.data_transform import data_manager | from mindinsight.datavisual.data_transform import data_manager | ||||
| @@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage | |||||
| from mindinsight.datavisual.utils import crc32 | from mindinsight.datavisual.utils import crc32 | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from ....utils.log_operations import LogOperations | |||||
| from ....utils.tools import check_loading_done, delete_files_or_dirs | |||||
| from ..mock import MockLogger | |||||
| class TestTrainTaskManager: | class TestTrainTaskManager: | ||||
| """Test train task manager.""" | """Test train task manager.""" | ||||
| @@ -70,39 +70,30 @@ class TestTrainTaskManager: | |||||
| @pytest.fixture(scope='function') | @pytest.fixture(scope='function') | ||||
| def load_data(self): | def load_data(self): | ||||
| """Load data.""" | """Load data.""" | ||||
| log_operation = LogOperations() | |||||
| self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []} | self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []} | ||||
| self._events_names = [] | self._events_names = [] | ||||
| self._train_id_list = [] | self._train_id_list = [] | ||||
| graph_base_path = os.path.join(os.path.dirname(__file__), | |||||
| os.pardir, "utils", "log_generators", "graph_base.json") | |||||
| self._root_dir = tempfile.mkdtemp() | self._root_dir = tempfile.mkdtemp() | ||||
| for i in range(self._dir_num): | for i in range(self._dir_num): | ||||
| dir_path = tempfile.mkdtemp(dir=self._root_dir) | dir_path = tempfile.mkdtemp(dir=self._root_dir) | ||||
| tmp_tag_name = self._tag_name + '_' + str(i) | tmp_tag_name = self._tag_name + '_' + str(i) | ||||
| event_name = str(i) + "_name" | event_name = str(i) + "_name" | ||||
| train_id = dir_path.replace(self._root_dir, ".") | train_id = dir_path.replace(self._root_dir, ".") | ||||
| # Pass timestamp to write to the same file. | # Pass timestamp to write to the same file. | ||||
| log_settings = dict( | |||||
| steps=self._steps_list, | |||||
| tag=tmp_tag_name, | |||||
| graph_base_path=graph_base_path, | |||||
| time=time.time()) | |||||
| log_settings = dict(steps=self._steps_list, tag=tmp_tag_name, time=time.time()) | |||||
| if i % 3 != 0: | if i % 3 != 0: | ||||
| LogOperations.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) | |||||
| log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) | |||||
| self._plugins_id_map['image'].append(train_id) | self._plugins_id_map['image'].append(train_id) | ||||
| if i % 3 != 1: | if i % 3 != 1: | ||||
| LogOperations.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings) | |||||
| log_operation.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings) | |||||
| self._plugins_id_map['scalar'].append(train_id) | self._plugins_id_map['scalar'].append(train_id) | ||||
| if i % 3 != 2: | if i % 3 != 2: | ||||
| LogOperations.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings) | |||||
| log_operation.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings) | |||||
| self._plugins_id_map['graph'].append(train_id) | self._plugins_id_map['graph'].append(train_id) | ||||
| self._events_names.append(event_name) | self._events_names.append(event_name) | ||||
| self._train_id_list.append(train_id) | self._train_id_list.append(train_id) | ||||
| self._generated_path.append(self._root_dir) | self._generated_path.append(self._root_dir) | ||||
| @@ -112,7 +103,8 @@ class TestTrainTaskManager: | |||||
| check_loading_done(self._mock_data_manager, time_limit=30) | check_loading_done(self._mock_data_manager, time_limit=30) | ||||
| def test_get_single_train_task_with_not_exists_train_id(self, load_data): | |||||
| @pytest.mark.usefixtures('load_data') | |||||
| def test_get_single_train_task_with_not_exists_train_id(self): | |||||
| """Test getting single train task with not exists train_id.""" | """Test getting single train task with not exists train_id.""" | ||||
| train_task_manager = TrainTaskManager(self._mock_data_manager) | train_task_manager = TrainTaskManager(self._mock_data_manager) | ||||
| for plugin_name in PluginNameEnum.list_members(): | for plugin_name in PluginNameEnum.list_members(): | ||||
| @@ -124,7 +116,8 @@ class TestTrainTaskManager: | |||||
| "the train job in data manager." | "the train job in data manager." | ||||
| assert exc_info.value.error_code == '50540002' | assert exc_info.value.error_code == '50540002' | ||||
| def test_get_single_train_task_with_params(self, load_data): | |||||
| @pytest.mark.usefixtures('load_data') | |||||
| def test_get_single_train_task_with_params(self): | |||||
| """Test getting single train task with params.""" | """Test getting single train task with params.""" | ||||
| train_task_manager = TrainTaskManager(self._mock_data_manager) | train_task_manager = TrainTaskManager(self._mock_data_manager) | ||||
| for plugin_name in PluginNameEnum.list_members(): | for plugin_name in PluginNameEnum.list_members(): | ||||
| @@ -138,7 +131,8 @@ class TestTrainTaskManager: | |||||
| else: | else: | ||||
| assert test_train_id not in self._plugins_id_map.get(plugin_name) | assert test_train_id not in self._plugins_id_map.get(plugin_name) | ||||
| def test_get_plugins_with_train_id(self, load_data): | |||||
| @pytest.mark.usefixtures('load_data') | |||||
| def test_get_plugins_with_train_id(self): | |||||
| """Test getting plugins with train id.""" | """Test getting plugins with train id.""" | ||||
| train_task_manager = TrainTaskManager(self._mock_data_manager) | train_task_manager = TrainTaskManager(self._mock_data_manager) | ||||
| @@ -1,14 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -1,85 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mask string to crc32.""" | |||||
| CRC_TABLE_32 = ( | |||||
| 0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, 0x8AD958CF, | |||||
| 0x78B2DBCC, 0x6BE22838, 0x9989AB3B, 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, 0x105EC76F, 0xE235446C, | |||||
| 0xF165B798, 0x030E349B, 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, | |||||
| 0x89D76C54, 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A, | |||||
| 0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, 0x6DFE410E, | |||||
| 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, 0xF779DEAE, 0x05125DAD, | |||||
| 0x1642AE59, 0xE4292D5A, 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, 0x7DA08661, 0x8FCB0562, 0x9C9BF696, | |||||
| 0x6EF07595, 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957, | |||||
| 0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, 0x5125DAD3, | |||||
| 0xA34E59D0, 0xB01EAA24, 0x42752927, 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, 0xDBFC821C, 0x2997011F, | |||||
| 0x3AC7F2EB, 0xC8AC71E8, 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, 0x61C69362, 0x93AD1061, 0x80FDE395, | |||||
| 0x72966096, 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859, | |||||
| 0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, 0xB602C312, | |||||
| 0x44694011, 0x5739B3E5, 0xA55230E6, 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, 0x3CDB9BDD, 0xCEB018DE, | |||||
| 0xDDE0EB2A, 0x2F8B6829, 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, 0x456CAC67, 0xB7072F64, 0xA457DC90, | |||||
| 0x563C5F93, 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C, | |||||
| 0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, 0x1871A4D8, | |||||
| 0xEA1A27DB, 0xF94AD42F, 0x0B21572C, 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, 0xA24BB5A6, 0x502036A5, | |||||
| 0x4370C551, 0xB11B4652, 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, | |||||
| 0x3BC21E9D, 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D, | |||||
| 0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, 0xFF56BD19, | |||||
| 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, 0x0417B1DB, 0xF67C32D8, | |||||
| 0xE52CC12C, 0x1747422F, 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, 0x8ECEE914, 0x7CA56A17, 0x6FF599E3, | |||||
| 0x9D9E1AE0, 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540, | |||||
| 0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, 0xE330A81A, | |||||
| 0x115B2B19, 0x020BD8ED, 0xF0605BEE, 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, 0x69E9F0D5, 0x9B8273D6, | |||||
| 0x88D28022, 0x7AB90321, 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, 0xF36E6F75, 0x0105EC76, 0x12551F82, | |||||
| 0xE03E9C81, 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E, | |||||
| 0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351 | |||||
| ) | |||||
| _CRC = 0 | |||||
| _MASK = 0xFFFFFFFF | |||||
| def _uint32(x): | |||||
| """Transform x's type to uint32.""" | |||||
| return x & 0xFFFFFFFF | |||||
| def _get_crc_checksum(crc, data): | |||||
| """Get crc checksum.""" | |||||
| crc ^= _MASK | |||||
| for d in data: | |||||
| crc_table_index = (crc ^ d) & 0xFF | |||||
| crc = (CRC_TABLE_32[crc_table_index] ^ (crc >> 8)) & _MASK | |||||
| crc ^= _MASK | |||||
| return crc | |||||
| def get_mask_from_string(data): | |||||
| """ | |||||
| Get masked crc from data. | |||||
| Args: | |||||
| data (byte): Byte string of data. | |||||
| Returns: | |||||
| uint32, masked crc. | |||||
| """ | |||||
| crc = _get_crc_checksum(_CRC, data) | |||||
| crc = _uint32(crc & _MASK) | |||||
| crc = _uint32(((crc >> 15) | _uint32(crc << 17)) + 0xA282EAD8) | |||||
| return crc | |||||
| @@ -1,14 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -1,166 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Log generator for images.""" | |||||
| import io | |||||
| import time | |||||
| import numpy as np | |||||
| from PIL import Image | |||||
| from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator | |||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||||
| class ImagesLogGenerator(LogGenerator): | |||||
| """ | |||||
| Log generator for images. | |||||
| This is a log generator writing images. User can use it to generate fake | |||||
| summary logs about images. | |||||
| """ | |||||
| def generate_event(self, values): | |||||
| """ | |||||
| Method for generating image event. | |||||
| Args: | |||||
| values (dict): A dict contains: | |||||
| { | |||||
| wall_time (float): Timestamp. | |||||
| step (int): Train step. | |||||
| image (np.array): Pixels tensor. | |||||
| tag (str): Tag name. | |||||
| } | |||||
| Returns: | |||||
| summary_pb2.Event. | |||||
| """ | |||||
| image_event = summary_pb2.Event() | |||||
| image_event.wall_time = values.get('wall_time') | |||||
| image_event.step = values.get('step') | |||||
| height, width, channel, image_string = self._get_image_string(values.get('image')) | |||||
| value = image_event.summary.value.add() | |||||
| value.tag = values.get('tag') | |||||
| value.image.height = height | |||||
| value.image.width = width | |||||
| value.image.colorspace = channel | |||||
| value.image.encoded_image = image_string | |||||
| return image_event | |||||
| def _get_image_string(self, image_tensor): | |||||
| """ | |||||
| Generate image string from tensor. | |||||
| Args: | |||||
| image_tensor (np.array): Pixels tensor. | |||||
| Returns: | |||||
| int, height. | |||||
| int, width. | |||||
| int, channel. | |||||
| bytes, image_string. | |||||
| """ | |||||
| height, width, channel = image_tensor.shape | |||||
| scaled_height = int(height) | |||||
| scaled_width = int(width) | |||||
| image = Image.fromarray(image_tensor) | |||||
| image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS) | |||||
| output = io.BytesIO() | |||||
| image.save(output, format='PNG') | |||||
| image_string = output.getvalue() | |||||
| output.close() | |||||
| return height, width, channel, image_string | |||||
| def _make_image_tensor(self, shape): | |||||
| """ | |||||
| Make image tensor according to shape. | |||||
| Args: | |||||
| shape (list): Shape of image, consists of height, width, channel. | |||||
| Returns: | |||||
| np.array, image tensor. | |||||
| """ | |||||
| image = np.prod(shape) | |||||
| image_tensor = (np.arange(image, dtype=float)).reshape(shape) | |||||
| image_tensor = image_tensor / np.max(image_tensor) * 255 | |||||
| image_tensor = image_tensor.astype(np.uint8) | |||||
| return image_tensor | |||||
| def generate_log(self, file_path, steps_list, tag_name): | |||||
| """ | |||||
| Generate log for external calls. | |||||
| Args: | |||||
| file_path (str): Path to write logs. | |||||
| steps_list (list): A list consists of step. | |||||
| tag_name (str): Tag name. | |||||
| Returns: | |||||
| list[dict], generated image metadata. | |||||
| dict, generated image tensors. | |||||
| """ | |||||
| images_values = dict() | |||||
| images_metadata = [] | |||||
| for step in steps_list: | |||||
| wall_time = time.time() | |||||
| # height, width, channel | |||||
| image_tensor = self._make_image_tensor([5, 5, 3]) | |||||
| image_metadata = dict() | |||||
| image_metadata.update({'wall_time': wall_time}) | |||||
| image_metadata.update({'step': step}) | |||||
| image_metadata.update({'height': image_tensor.shape[0]}) | |||||
| image_metadata.update({'width': image_tensor.shape[1]}) | |||||
| images_metadata.append(image_metadata) | |||||
| images_values.update({step: image_tensor}) | |||||
| values = dict( | |||||
| wall_time=wall_time, | |||||
| step=step, | |||||
| image=image_tensor, | |||||
| tag=tag_name | |||||
| ) | |||||
| self._write_log_one_step(file_path, values) | |||||
| return images_metadata, images_values | |||||
| def get_image_tensor_from_bytes(self, image_string): | |||||
| """Get image tensor from bytes.""" | |||||
| img = Image.open(io.BytesIO(image_string)) | |||||
| image_tensor = np.array(img) | |||||
| return image_tensor | |||||
| if __name__ == "__main__": | |||||
| images_log_generator = ImagesLogGenerator() | |||||
| test_file_name = '%s.%s.%s' % ('image', 'summary', str(time.time())) | |||||
| test_steps = [1, 3, 5] | |||||
| test_tags = "test_image_tag_name" | |||||
| images_log_generator.generate_log(test_file_name, test_steps, test_tags) | |||||
| @@ -1,75 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Base log Generator.""" | |||||
| import struct | |||||
| from abc import abstractmethod | |||||
| from tests.ut.datavisual.utils import crc32 | |||||
| class LogGenerator: | |||||
| """ | |||||
| Base log generator. | |||||
| This is a base class for log generators. User can use it to generate fake | |||||
| summary logs. | |||||
| """ | |||||
| @abstractmethod | |||||
| def generate_event(self, values): | |||||
| """ | |||||
| Abstract method for generating event. | |||||
| Args: | |||||
| values (dict): Values. | |||||
| Returns: | |||||
| summary_pb2.Event. | |||||
| """ | |||||
| def _write_log_one_step(self, file_path, values): | |||||
| """ | |||||
| Write log one step. | |||||
| Args: | |||||
| file_path (str): File path to write. | |||||
| values (dict): Values. | |||||
| """ | |||||
| event = self.generate_event(values) | |||||
| self._write_log_from_event(file_path, event) | |||||
| @staticmethod | |||||
| def _write_log_from_event(file_path, event): | |||||
| """ | |||||
| Write log by event. | |||||
| Args: | |||||
| file_path (str): File path to write. | |||||
| event (summary_pb2.Event): Event object in proto. | |||||
| """ | |||||
| send_msg = event.SerializeToString() | |||||
| header = struct.pack('<Q', len(send_msg)) | |||||
| header_crc = struct.pack('<I', crc32.get_mask_from_string(header)) | |||||
| footer_crc = struct.pack('<I', crc32.get_mask_from_string(send_msg)) | |||||
| write_event = header + header_crc + send_msg + footer_crc | |||||
| with open(file_path, "ab") as f: | |||||
| f.write(write_event) | |||||
| @@ -1,100 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Log generator for scalars.""" | |||||
| import time | |||||
| import numpy as np | |||||
| from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator | |||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||||
| class ScalarsLogGenerator(LogGenerator): | |||||
| """ | |||||
| Log generator for scalars. | |||||
| This is a log generator writing scalars. User can use it to generate fake | |||||
| summary logs about scalar. | |||||
| """ | |||||
| def generate_event(self, values): | |||||
| """ | |||||
| Method for generating scalar event. | |||||
| Args: | |||||
| values (dict): A dict contains: | |||||
| { | |||||
| wall_time (float): Timestamp. | |||||
| step (int): Train step. | |||||
| value (float): Scalar value. | |||||
| tag (str): Tag name. | |||||
| } | |||||
| Returns: | |||||
| summary_pb2.Event. | |||||
| """ | |||||
| scalar_event = summary_pb2.Event() | |||||
| scalar_event.wall_time = values.get('wall_time') | |||||
| scalar_event.step = values.get('step') | |||||
| value = scalar_event.summary.value.add() | |||||
| value.tag = values.get('tag') | |||||
| value.scalar_value = values.get('value') | |||||
| return scalar_event | |||||
| def generate_log(self, file_path, steps_list, tag_name): | |||||
| """ | |||||
| Generate log for external calls. | |||||
| Args: | |||||
| file_path (str): Path to write logs. | |||||
| steps_list (list): A list consists of step. | |||||
| tag_name (str): Tag name. | |||||
| Returns: | |||||
| list[dict], generated scalar metadata. | |||||
| None, to be consistent with return value of ImageGenerator. | |||||
| """ | |||||
| scalars_metadata = [] | |||||
| for step in steps_list: | |||||
| scalar_metadata = dict() | |||||
| wall_time = time.time() | |||||
| value = np.random.rand() | |||||
| scalar_metadata.update({'wall_time': wall_time}) | |||||
| scalar_metadata.update({'step': step}) | |||||
| scalar_metadata.update({'value': value}) | |||||
| scalars_metadata.append(scalar_metadata) | |||||
| scalar_metadata.update({"tag": tag_name}) | |||||
| self._write_log_one_step(file_path, scalar_metadata) | |||||
| return scalars_metadata, None | |||||
| if __name__ == "__main__": | |||||
| scalars_log_generator = ScalarsLogGenerator() | |||||
| test_file_name = '%s.%s.%s' % ('scalar', 'summary', str(time.time())) | |||||
| test_steps = [1, 3, 5] | |||||
| test_tag = "test_scalar_tag_name" | |||||
| scalars_log_generator.generate_log(test_file_name, test_steps, test_tag) | |||||
| @@ -1,83 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """log operations module.""" | |||||
| import json | |||||
| import os | |||||
| import time | |||||
| from tests.ut.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator | |||||
| from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator | |||||
| from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||||
| log_generators = { | |||||
| PluginNameEnum.GRAPH.value: GraphLogGenerator(), | |||||
| PluginNameEnum.IMAGE.value: ImagesLogGenerator(), | |||||
| PluginNameEnum.SCALAR.value: ScalarsLogGenerator() | |||||
| } | |||||
| class LogOperations: | |||||
| """Log Operations class.""" | |||||
| @staticmethod | |||||
| def generate_log(plugin_name, log_dir, log_settings, valid=True): | |||||
| """ | |||||
| Generate log. | |||||
| Args: | |||||
| plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'. | |||||
| log_dir (str): Log path to write log. | |||||
| log_settings (dict): Info about the log, e.g.: | |||||
| { | |||||
| current_time (int): Timestamp in summary file name, not necessary. | |||||
| graph_base_path (str): Path of graph_bas.json, necessary for `graph`. | |||||
| steps (list[int]): Steps for `image` and `scalar`, default is [1]. | |||||
| tag (str): Tag name, default is 'default_tag'. | |||||
| } | |||||
| valid (bool): If true, summary name will be valid. | |||||
| Returns: | |||||
| str, Summary log path. | |||||
| """ | |||||
| current_time = log_settings.get('time', int(time.time())) | |||||
| current_time = int(current_time) | |||||
| log_generator = log_generators.get(plugin_name) | |||||
| if valid: | |||||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time))) | |||||
| else: | |||||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time))) | |||||
| if plugin_name == PluginNameEnum.GRAPH.value: | |||||
| graph_base_path = log_settings.get('graph_base_path') | |||||
| with open(graph_base_path, 'r') as load_f: | |||||
| graph_dict = json.load(load_f) | |||||
| graph_dict = log_generator.generate_log(temp_path, graph_dict) | |||||
| return temp_path, graph_dict | |||||
| steps_list = log_settings.get('steps', [1]) | |||||
| tag_name = log_settings.get('tag', 'default_tag') | |||||
| metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name) | |||||
| return temp_path, metadata, values | |||||
| @staticmethod | |||||
| def get_log_generator(plugin_name): | |||||
| """Get log generator.""" | |||||
| return log_generators.get(plugin_name) | |||||
| @@ -1,59 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Description: This file is used for some common util. | |||||
| """ | |||||
| import os | |||||
| import shutil | |||||
| import time | |||||
| from urllib.parse import urlencode | |||||
| from mindinsight.datavisual.common.enums import DataManagerStatus | |||||
| def get_url(url, params): | |||||
| """ | |||||
| Concatenate the URL and params. | |||||
| Args: | |||||
| url (str): A link requested. For example, http://example.com. | |||||
| params (dict): A dict consists of params. For example, {'offset': 1, 'limit':'100}. | |||||
| Returns: | |||||
| str, like http://example.com?offset=1&limit=100 | |||||
| """ | |||||
| return url + '?' + urlencode(params) | |||||
| def delete_files_or_dirs(path_list): | |||||
| """Delete files or dirs in path_list.""" | |||||
| for path in path_list: | |||||
| if os.path.isdir(path): | |||||
| shutil.rmtree(path) | |||||
| else: | |||||
| os.remove(path) | |||||
| def check_loading_done(data_manager, time_limit=15): | |||||
| """If loading data for more than `time_limit` seconds, exit.""" | |||||
| start_time = time.time() | |||||
| while data_manager.status != DataManagerStatus.DONE.value: | |||||
| time_used = time.time() - start_time | |||||
| if time_used > time_limit: | |||||
| break | |||||
| time.sleep(0.1) | |||||
| continue | |||||
| @@ -14,6 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Import the mocked mindspore.""" | """Import the mocked mindspore.""" | ||||
| import sys | import sys | ||||
| from .collection.model import mindspore | |||||
| from ...utils import mindspore | |||||
| sys.modules['mindspore'] = mindspore | sys.modules['mindspore'] = mindspore | ||||
| @@ -1,21 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock MindSpore Interface.""" | |||||
| from .application.model_zoo.resnet import ResNet | |||||
| from .common.tensor import Tensor | |||||
| from .dataset import MindDataset | |||||
| from .nn import * | |||||
| from .train.callback import _ListCallback, Callback, RunContext, ModelCheckpoint, SummaryStep | |||||
| from .train.summary import SummaryRecord | |||||
| @@ -1,14 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -1,14 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -1,24 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore ResNet class.""" | |||||
| from ...nn.cell import Cell | |||||
| class ResNet(Cell): | |||||
| """Mocked ResNet.""" | |||||
| def __init__(self): | |||||
| super(ResNet, self).__init__() | |||||
| self._cells = {} | |||||
| @@ -1,14 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -1,30 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore mindspore/common/tensor.py.""" | |||||
| import numpy as np | |||||
| class Tensor: | |||||
| """Mock the MindSpore Tensor class.""" | |||||
| def __init__(self, value=0): | |||||
| self._value = value | |||||
| def asnumpy(self): | |||||
| """Get value in numpy format.""" | |||||
| return np.array(self._value) | |||||
| def __repr__(self): | |||||
| return str(self.asnumpy()) | |||||
| @@ -1,21 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MindSpore Mock Interface""" | |||||
| def get_context(key): | |||||
| """Get key in context.""" | |||||
| context = {"device_id": 1} | |||||
| return context.get(key) | |||||
| @@ -1,16 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock mindspore.dataset.""" | |||||
| from .engine import MindDataset | |||||
| @@ -1,16 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock mindspore.dataset.engine.""" | |||||
| from .datasets import MindDataset, Dataset | |||||
| @@ -1,36 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore mindspore/dataset/engine/datasets.py.""" | |||||
| class Dataset: | |||||
| """Mock the MindSpore Dataset class.""" | |||||
| def __init__(self, dataset_size=None, dataset_path=None): | |||||
| self.dataset_size = dataset_size | |||||
| self.dataset_path = dataset_path | |||||
| self.input = [] | |||||
| def get_dataset_size(self): | |||||
| """Mocked get_dataset_size.""" | |||||
| return self.dataset_size | |||||
| class MindDataset(Dataset): | |||||
| """Mock the MindSpore MindDataset class.""" | |||||
| def __init__(self, dataset_size=None, dataset_file=None): | |||||
| super(MindDataset, self).__init__(dataset_size) | |||||
| self.dataset_file = dataset_file | |||||
| @@ -1,22 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the mindspore.nn package.""" | |||||
| from .optim import Optimizer, Momentum | |||||
| from .loss.loss import SoftmaxCrossEntropyWithLogits, _Loss | |||||
| from .cell import Cell, WithLossCell, TrainOneStepWithLossScaleCell | |||||
| __all__ = ['Optimizer', 'Momentum', 'SoftmaxCrossEntropyWithLogits', | |||||
| '_Loss', 'Cell', 'WithLossCell', | |||||
| 'TrainOneStepWithLossScaleCell'] | |||||
| @@ -1,51 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore mindspore/train/callback.py.""" | |||||
| class Cell: | |||||
| """Mock the Cell class.""" | |||||
| def __init__(self, auto_prefix=True, pips=None): | |||||
| if pips is None: | |||||
| pips = dict() | |||||
| self._auto_prefix = auto_prefix | |||||
| self._pips = pips | |||||
| @property | |||||
| def auto_prefix(self): | |||||
| """The property of auto_prefix.""" | |||||
| return self._auto_prefix | |||||
| @property | |||||
| def pips(self): | |||||
| """The property of pips.""" | |||||
| return self._pips | |||||
| class WithLossCell(Cell): | |||||
| """Mocked WithLossCell class.""" | |||||
| def __init__(self, backbone, loss_fn): | |||||
| super(WithLossCell, self).__init__(auto_prefix=False, pips=backbone.pips) | |||||
| self._backbone = backbone | |||||
| self._loss_fn = loss_fn | |||||
| class TrainOneStepWithLossScaleCell(Cell): | |||||
| """Mocked TrainOneStepWithLossScaleCell.""" | |||||
| def __init__(self): | |||||
| super(TrainOneStepWithLossScaleCell, self).__init__() | |||||
| @@ -1,14 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -1,39 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore SoftmaxCrossEntropyWithLogits class.""" | |||||
| from ..cell import Cell | |||||
| class _Loss(Cell): | |||||
| """Mocked _Loss.""" | |||||
| def __init__(self, reduction='mean'): | |||||
| super(_Loss, self).__init__() | |||||
| self.reduction = reduction | |||||
| def construct(self, base, target): | |||||
| """Mocked construct function.""" | |||||
| raise NotImplementedError | |||||
| class SoftmaxCrossEntropyWithLogits(_Loss): | |||||
| """Mocked SoftmaxCrossEntropyWithLogits.""" | |||||
| def __init__(self, weight=None): | |||||
| super(SoftmaxCrossEntropyWithLogits, self).__init__(weight) | |||||
| def construct(self, base, target): | |||||
| """Mocked construct.""" | |||||
| return 1 | |||||
| @@ -1,49 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore mindspore/nn/optim.py.""" | |||||
| from .cell import Cell | |||||
| class Parameter: | |||||
| """Mock the MindSpore Parameter class.""" | |||||
| def __init__(self, learning_rate): | |||||
| self._name = "Parameter" | |||||
| self.default_input = learning_rate | |||||
| @property | |||||
| def name(self): | |||||
| """The property of name.""" | |||||
| return self._name | |||||
| def __repr__(self): | |||||
| format_str = 'Parameter (name={name})' | |||||
| return format_str.format(name=self._name) | |||||
| class Optimizer(Cell): | |||||
| """Mock the MindSpore Optimizer class.""" | |||||
| def __init__(self, learning_rate): | |||||
| super(Optimizer, self).__init__() | |||||
| self.learning_rate = Parameter(learning_rate) | |||||
| class Momentum(Optimizer): | |||||
| """Mock the MindSpore Momentum class.""" | |||||
| def __init__(self, learning_rate): | |||||
| super(Momentum, self).__init__(learning_rate) | |||||
| self.dynamic_lr = False | |||||
| @@ -1,17 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock MindSpore wrap package.""" | |||||
| from .loss_scale import TrainOneStepWithLossScaleCell | |||||
| from .cell_wrapper import WithLossCell | |||||
| @@ -1,25 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock MindSpore cell_wrapper.py.""" | |||||
| from ..cell import Cell | |||||
| class WithLossCell(Cell): | |||||
| """Mock the WithLossCell class.""" | |||||
| def __init__(self, backbone, loss_fn): | |||||
| super(WithLossCell, self).__init__() | |||||
| self._backbone = backbone | |||||
| self._loss_fn = loss_fn | |||||
| @@ -1,29 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock MindSpore loss_scale.py.""" | |||||
| from ..cell import Cell | |||||
| class TrainOneStepWithLossScaleCell(Cell): | |||||
| """Mock the TrainOneStepWithLossScaleCell class.""" | |||||
| def __init__(self, network, optimizer): | |||||
| super(TrainOneStepWithLossScaleCell, self).__init__() | |||||
| self.network = network | |||||
| self.optimizer = optimizer | |||||
| def construct(self, data, label): | |||||
| """Mock the construct method.""" | |||||
| raise NotImplementedError | |||||
| @@ -1,14 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -1,84 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mock the MindSpore mindspore/train/callback.py.""" | |||||
| import os | |||||
| class RunContext: | |||||
| """Mock the RunContext class.""" | |||||
| def __init__(self, original_args=None): | |||||
| self._original_args = original_args | |||||
| self._stop_requested = False | |||||
| def original_args(self): | |||||
| """Mock original_args.""" | |||||
| return self._original_args | |||||
| def stop_requested(self): | |||||
| """Mock stop_requested method.""" | |||||
| return self._stop_requested | |||||
| class Callback: | |||||
| """Mock the Callback class.""" | |||||
| def __init__(self): | |||||
| pass | |||||
| def begin(self, run_context): | |||||
| """Called once before network training.""" | |||||
| def epoch_begin(self, run_context): | |||||
| """Called before each epoch begin.""" | |||||
| class _ListCallback(Callback): | |||||
| """Mock the _ListCallabck class.""" | |||||
| def __init__(self, callbacks): | |||||
| super(_ListCallback, self).__init__() | |||||
| self._callbacks = callbacks | |||||
| class ModelCheckpoint(Callback): | |||||
| """Mock the ModelCheckpoint class.""" | |||||
| def __init__(self, prefix='CKP', directory=None, config=None): | |||||
| super(ModelCheckpoint, self).__init__() | |||||
| self._prefix = prefix | |||||
| self._directory = directory | |||||
| self._config = config | |||||
| self._latest_ckpt_file_name = os.path.join(directory, prefix + 'test_model.ckpt') | |||||
| @property | |||||
| def model_file_name(self): | |||||
| """Get the file name of model.""" | |||||
| return self._model_file_name | |||||
| @property | |||||
| def latest_ckpt_file_name(self): | |||||
| """Get the latest file name fo checkpoint.""" | |||||
| return self._latest_ckpt_file_name | |||||
| class SummaryStep(Callback): | |||||
| """Mock the SummaryStep class.""" | |||||
| def __init__(self, summary, flush_step=10): | |||||
| super(SummaryStep, self).__init__() | |||||
| self._sumamry = summary | |||||
| self._flush_step = flush_step | |||||
| self.summary_file_name = summary.full_file_name | |||||
| @@ -1,18 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MindSpore Mock Interface""" | |||||
| from .summary_record import SummaryRecord | |||||
| __all__ = ["SummaryRecord"] | |||||
| @@ -1,38 +0,0 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MindSpore Mock Interface""" | |||||
| import os | |||||
| import time | |||||
| class SummaryRecord: | |||||
| """Mock the MindSpore SummaryRecord class.""" | |||||
| def __init__(self, | |||||
| log_dir: str, | |||||
| file_prefix: str = "events.", | |||||
| file_suffix: str = ".MS", | |||||
| create_time=int(time.time())): | |||||
| self.log_dir = log_dir | |||||
| self.prefix = file_prefix | |||||
| self.suffix = file_suffix | |||||
| file_name = file_prefix + 'summary.' + str(create_time) + file_suffix | |||||
| self.full_file_name = os.path.join(log_dir, file_name) | |||||
| def flush(self): | |||||
| """Mock flush method.""" | |||||
| def close(self): | |||||
| """Mock close method.""" | |||||
| @@ -16,21 +16,21 @@ | |||||
| import os | import os | ||||
| import shutil | import shutil | ||||
| import unittest | import unittest | ||||
| from unittest import mock, TestCase | |||||
| from unittest import TestCase, mock | |||||
| from unittest.mock import MagicMock | from unittest.mock import MagicMock | ||||
| from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ | |||||
| AnalyzeObject | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||||
| LineageLogError, LineageGetModelFileError, MindInsightException | |||||
| from mindinsight.lineagemgr.collection.model.model_lineage import AnalyzeObject, EvalLineage, TrainLineage | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError, | |||||
| MindInsightException) | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.dataset.engine import MindDataset, Dataset | |||||
| from mindspore.nn import Optimizer, WithLossCell, TrainOneStepWithLossScaleCell, \ | |||||
| SoftmaxCrossEntropyWithLogits | |||||
| from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep | |||||
| from mindspore.dataset.engine import Dataset, MindDataset | |||||
| from mindspore.nn import Optimizer, SoftmaxCrossEntropyWithLogits, TrainOneStepWithLossScaleCell, WithLossCell | |||||
| from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep | |||||
| from mindspore.train.summary import SummaryRecord | from mindspore.train.summary import SummaryRecord | ||||
| @mock.patch('builtins.open') | |||||
| @mock.patch('os.makedirs') | |||||
| class TestModelLineage(TestCase): | class TestModelLineage(TestCase): | ||||
| """Test TrainLineage and EvalLineage class in model_lineage.py.""" | """Test TrainLineage and EvalLineage class in model_lineage.py.""" | ||||
| @@ -51,23 +51,19 @@ class TestModelLineage(TestCase): | |||||
| cls.summary_log_path = '/path/to/summary_log' | cls.summary_log_path = '/path/to/summary_log' | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | ||||
| def test_summary_record_exception(self, mock_validate_summary): | |||||
| def test_summary_record_exception(self, *args): | |||||
| """Test SummaryRecord with exception.""" | """Test SummaryRecord with exception.""" | ||||
| mock_validate_summary.return_value = None | |||||
| args[0].return_value = None | |||||
| summary_record = self.my_summary_record(self.summary_log_path) | summary_record = self.my_summary_record(self.summary_log_path) | ||||
| with self.assertRaises(MindInsightException) as context: | with self.assertRaises(MindInsightException) as context: | ||||
| self.my_train_module(summary_record=summary_record, raise_exception=1) | self.my_train_module(summary_record=summary_record, raise_exception=1) | ||||
| self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception)) | self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception)) | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'LineageSummary.record_dataset_graph') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'validate_summary_record') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'AnalyzeObject.get_optimizer_by_network') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | ||||
| def test_begin(self, *args): | def test_begin(self, *args): | ||||
| """Test TrainLineage.begin method.""" | """Test TrainLineage.begin method.""" | ||||
| @@ -82,14 +78,10 @@ class TestModelLineage(TestCase): | |||||
| args[4].assert_called() | args[4].assert_called() | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'LineageSummary.record_dataset_graph') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'validate_summary_record') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'AnalyzeObject.get_optimizer_by_network') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' | |||||
| 'AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | ||||
| def test_begin_error(self, *args): | def test_begin_error(self, *args): | ||||
| """Test TrainLineage.begin method.""" | """Test TrainLineage.begin method.""" | ||||
| @@ -122,15 +114,11 @@ class TestModelLineage(TestCase): | |||||
| train_lineage.begin(self.my_run_context(run_context)) | train_lineage.begin(self.my_run_context(run_context)) | ||||
| self.assertTrue('The parameter optimizer is invalid.' in str(context.exception)) | self.assertTrue('The parameter optimizer is invalid.' in str(context.exception)) | ||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | ||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | ||||
| @mock.patch('builtins.float') | @mock.patch('builtins.float') | ||||
| @@ -150,23 +138,19 @@ class TestModelLineage(TestCase): | |||||
| args[6].assert_called() | args[6].assert_called() | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | ||||
| def test_train_end_exception(self, mock_validate_summary): | |||||
| def test_train_end_exception(self, *args): | |||||
| """Test TrainLineage.end method when exception.""" | """Test TrainLineage.end method when exception.""" | ||||
| mock_validate_summary.return_value = True | |||||
| args[0].return_value = True | |||||
| train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True) | train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True) | ||||
| with self.assertRaises(Exception) as context: | with self.assertRaises(Exception) as context: | ||||
| train_lineage.end(self.run_context) | train_lineage.end(self.run_context) | ||||
| self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception)) | self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception)) | ||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | ||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | ||||
| @mock.patch('builtins.float') | @mock.patch('builtins.float') | ||||
| @@ -186,15 +170,11 @@ class TestModelLineage(TestCase): | |||||
| train_lineage.end(self.my_run_context(self.run_context)) | train_lineage.end(self.my_run_context(self.run_context)) | ||||
| self.assertTrue('End error in TrainLineage:' in str(context.exception)) | self.assertTrue('End error in TrainLineage:' in str(context.exception)) | ||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') | ||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||||
| @mock.patch( | |||||
| 'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer') | |||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') | ||||
| @mock.patch('builtins.float') | @mock.patch('builtins.float') | ||||
| @@ -218,9 +198,9 @@ class TestModelLineage(TestCase): | |||||
| self.assertTrue('End error in TrainLineage:' in str(context.exception)) | self.assertTrue('End error in TrainLineage:' in str(context.exception)) | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | ||||
| def test_eval_exception_train_id_none(self, mock_validate_summary): | |||||
| def test_eval_exception_train_id_none(self, *args): | |||||
| """Test EvalLineage.end method with initialization error.""" | """Test EvalLineage.end method with initialization error.""" | ||||
| mock_validate_summary.return_value = True | |||||
| args[0].return_value = True | |||||
| with self.assertRaises(MindInsightException) as context: | with self.assertRaises(MindInsightException) as context: | ||||
| self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2) | self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2) | ||||
| self.assertTrue('Invalid value for raise_exception.' in str(context.exception)) | self.assertTrue('Invalid value for raise_exception.' in str(context.exception)) | ||||
| @@ -242,9 +222,9 @@ class TestModelLineage(TestCase): | |||||
| args[0].assert_called() | args[0].assert_called() | ||||
| @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') | ||||
| def test_eval_end_except_run_context(self, mock_validate_summary): | |||||
| def test_eval_end_except_run_context(self, *args): | |||||
| """Test EvalLineage.end method when run_context is invalid..""" | """Test EvalLineage.end method when run_context is invalid..""" | ||||
| mock_validate_summary.return_value = True | |||||
| args[0].return_value = True | |||||
| eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True) | eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True) | ||||
| with self.assertRaises(Exception) as context: | with self.assertRaises(Exception) as context: | ||||
| eval_lineage.end(self.run_context) | eval_lineage.end(self.run_context) | ||||
| @@ -284,8 +264,9 @@ class TestModelLineage(TestCase): | |||||
| eval_lineage.end(self.my_run_context(self.run_context)) | eval_lineage.end(self.my_run_context(self.run_context)) | ||||
| self.assertTrue('End error in EvalLineage' in str(context.exception)) | self.assertTrue('End error in EvalLineage' in str(context.exception)) | ||||
| def test_epoch_is_zero(self): | |||||
| def test_epoch_is_zero(self, *args): | |||||
| """Test TrainLineage.end method.""" | """Test TrainLineage.end method.""" | ||||
| args[0].return_value = None | |||||
| run_context = self.run_context | run_context = self.run_context | ||||
| run_context['epoch_num'] = 0 | run_context['epoch_num'] = 0 | ||||
| with self.assertRaises(MindInsightException): | with self.assertRaises(MindInsightException): | ||||
| @@ -345,7 +326,7 @@ class TestAnalyzer(TestCase): | |||||
| ) | ) | ||||
| res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train') | res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train') | ||||
| res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid') | res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid') | ||||
| assert res1 == {'step_num': 10, | |||||
| assert res1 == {'step_num': 10, | |||||
| 'train_dataset_path': '/path/to', | 'train_dataset_path': '/path/to', | ||||
| 'train_dataset_size': 50, | 'train_dataset_size': 50, | ||||
| 'epoch': 2} | 'epoch': 2} | ||||
| @@ -20,23 +20,44 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||||
| from mindinsight.lineagemgr.common.path_parser import SummaryPathParser | from mindinsight.lineagemgr.common.path_parser import SummaryPathParser | ||||
| MOCK_SUMMARY_DIRS = [ | |||||
| { | |||||
| 'relative_path': './relative_path0' | |||||
| }, | |||||
| { | |||||
| 'relative_path': './' | |||||
| }, | |||||
| { | |||||
| 'relative_path': './relative_path1' | |||||
| } | |||||
| ] | |||||
| MOCK_SUMMARIES = [ | |||||
| { | |||||
| 'file_name': 'file0', | |||||
| 'create_time': datetime.fromtimestamp(1582031970) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file0_lineage', | |||||
| 'create_time': datetime.fromtimestamp(1582031970) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file1', | |||||
| 'create_time': datetime.fromtimestamp(1582031971) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file1_lineage', | |||||
| 'create_time': datetime.fromtimestamp(1582031971) | |||||
| } | |||||
| ] | |||||
| class TestSummaryPathParser(TestCase): | class TestSummaryPathParser(TestCase): | ||||
| """Test the class of SummaryPathParser.""" | """Test the class of SummaryPathParser.""" | ||||
| @mock.patch.object(SummaryWatcher, 'list_summary_directories') | @mock.patch.object(SummaryWatcher, 'list_summary_directories') | ||||
| def test_get_summary_dirs(self, *args): | def test_get_summary_dirs(self, *args): | ||||
| """Test the function of get_summary_dirs.""" | """Test the function of get_summary_dirs.""" | ||||
| args[0].return_value = [ | |||||
| { | |||||
| 'relative_path': './relative_path0' | |||||
| }, | |||||
| { | |||||
| 'relative_path': './' | |||||
| }, | |||||
| { | |||||
| 'relative_path': './relative_path1' | |||||
| } | |||||
| ] | |||||
| args[0].return_value = MOCK_SUMMARY_DIRS | |||||
| expected_result = [ | expected_result = [ | ||||
| '/path/to/base/relative_path0', | '/path/to/base/relative_path0', | ||||
| @@ -54,24 +75,7 @@ class TestSummaryPathParser(TestCase): | |||||
| @mock.patch.object(SummaryWatcher, 'list_summaries') | @mock.patch.object(SummaryWatcher, 'list_summaries') | ||||
| def test_get_latest_lineage_summary(self, *args): | def test_get_latest_lineage_summary(self, *args): | ||||
| """Test the function of get_latest_lineage_summary.""" | """Test the function of get_latest_lineage_summary.""" | ||||
| args[0].return_value = [ | |||||
| { | |||||
| 'file_name': 'file0', | |||||
| 'create_time': datetime.fromtimestamp(1582031970) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file0_lineage', | |||||
| 'create_time': datetime.fromtimestamp(1582031970) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file1', | |||||
| 'create_time': datetime.fromtimestamp(1582031971) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file1_lineage', | |||||
| 'create_time': datetime.fromtimestamp(1582031971) | |||||
| } | |||||
| ] | |||||
| args[0].return_value = MOCK_SUMMARIES | |||||
| summary_dir = '/path/to/summary_dir' | summary_dir = '/path/to/summary_dir' | ||||
| result = SummaryPathParser.get_latest_lineage_summary(summary_dir) | result = SummaryPathParser.get_latest_lineage_summary(summary_dir) | ||||
| self.assertEqual('/path/to/summary_dir/file1_lineage', result) | self.assertEqual('/path/to/summary_dir/file1_lineage', result) | ||||
| @@ -119,35 +123,8 @@ class TestSummaryPathParser(TestCase): | |||||
| @mock.patch.object(SummaryWatcher, 'list_summary_directories') | @mock.patch.object(SummaryWatcher, 'list_summary_directories') | ||||
| def test_get_latest_lineage_summaries(self, *args): | def test_get_latest_lineage_summaries(self, *args): | ||||
| """Test the function of get_latest_lineage_summaries.""" | """Test the function of get_latest_lineage_summaries.""" | ||||
| args[0].return_value = [ | |||||
| { | |||||
| 'relative_path': './relative_path0' | |||||
| }, | |||||
| { | |||||
| 'relative_path': './' | |||||
| }, | |||||
| { | |||||
| 'relative_path': './relative_path1' | |||||
| } | |||||
| ] | |||||
| args[1].return_value = [ | |||||
| { | |||||
| 'file_name': 'file0', | |||||
| 'create_time': datetime.fromtimestamp(1582031970) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file0_lineage', | |||||
| 'create_time': datetime.fromtimestamp(1582031970) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file1', | |||||
| 'create_time': datetime.fromtimestamp(1582031971) | |||||
| }, | |||||
| { | |||||
| 'file_name': 'file1_lineage', | |||||
| 'create_time': datetime.fromtimestamp(1582031971) | |||||
| } | |||||
| ] | |||||
| args[0].return_value = MOCK_SUMMARY_DIRS | |||||
| args[1].return_value = MOCK_SUMMARIES | |||||
| expected_result = [ | expected_result = [ | ||||
| '/path/to/base/relative_path0/file1_lineage', | '/path/to/base/relative_path0/file1_lineage', | ||||
| @@ -15,38 +15,31 @@ | |||||
| """Test the validate module.""" | """Test the validate module.""" | ||||
| from unittest import TestCase | from unittest import TestCase | ||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||||
| LineageParamValueError, LineageParamTypeError | |||||
| from mindinsight.lineagemgr.common.validator.model_parameter import \ | |||||
| SearchModelConditionParameter | |||||
| from mindinsight.lineagemgr.common.validator.validate import \ | |||||
| validate_search_model_condition | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError | |||||
| from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter | |||||
| from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition | |||||
| from mindinsight.utils.exceptions import MindInsightException | from mindinsight.utils.exceptions import MindInsightException | ||||
| class TestValidateSearchModelCondition(TestCase): | class TestValidateSearchModelCondition(TestCase): | ||||
| """Test the mothod of validate_search_model_condition.""" | """Test the mothod of validate_search_model_condition.""" | ||||
| def test_validate_search_model_condition(self): | |||||
| """Test the mothod of validate_search_model_condition.""" | |||||
| def test_validate_search_model_condition_param_type_error(self): | |||||
| """Test the mothod of validate_search_model_condition with LineageParamTypeError.""" | |||||
| condition = { | condition = { | ||||
| 'summary_dir': 'xxx' | 'summary_dir': 'xxx' | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| LineageParamTypeError, | |||||
| self._assert_raise_of_lineage_param_type_error( | |||||
| 'The search_condition element summary_dir should be dict.', | 'The search_condition element summary_dir should be dict.', | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| def test_validate_search_model_condition_param_value_error(self): | |||||
| """Test the mothod of validate_search_model_condition with LineageParamValueError.""" | |||||
| condition = { | condition = { | ||||
| 'xxx': 'xxx' | 'xxx': 'xxx' | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| LineageParamValueError, | |||||
| self._assert_raise_of_lineage_param_value_error( | |||||
| 'The search attribute not supported.', | 'The search attribute not supported.', | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -55,22 +48,38 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'xxx': 'xxx' | 'xxx': 'xxx' | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| LineageParamValueError, | |||||
| self._assert_raise_of_lineage_param_value_error( | |||||
| "The compare condition should be in", | "The compare condition should be in", | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| condition = { | |||||
| 1: { | |||||
| "ge": 8.0 | |||||
| } | |||||
| } | |||||
| self._assert_raise_of_lineage_param_value_error( | |||||
| "The search attribute not supported.", | |||||
| condition | |||||
| ) | |||||
| condition = { | |||||
| 'metric_': { | |||||
| "ge": 8.0 | |||||
| } | |||||
| } | |||||
| self._assert_raise_of_lineage_param_value_error( | |||||
| "The search attribute not supported.", | |||||
| condition | |||||
| ) | |||||
| def test_validate_search_model_condition_mindinsight_exception_1(self): | |||||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||||
| condition = { | condition = { | ||||
| "offset": 100001 | "offset": 100001 | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "Invalid input offset. 0 <= offset <= 100000", | "Invalid input offset. 0 <= offset <= 100000", | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -80,11 +89,9 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| }, | }, | ||||
| 'limit': 10 | 'limit': 10 | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| "The parameter summary_dir is invalid. It should be a dict and the value should be a string", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter summary_dir is invalid. It should be a dict and " | |||||
| "the value should be a string", | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -93,11 +100,9 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'in': 1.0 | 'in': 1.0 | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter learning_rate is invalid. It should be a dict and " | |||||
| "the value should be a float or a integer", | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -106,24 +111,22 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'lt': True | 'lt': True | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter learning_rate is invalid. It should be a dict and " | |||||
| "the value should be a float or a integer", | |||||
| condition | condition | ||||
| ) | ) | ||||
| def test_validate_search_model_condition_mindinsight_exception_2(self): | |||||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||||
| condition = { | condition = { | ||||
| 'learning_rate': { | 'learning_rate': { | ||||
| 'gt': [1.0] | 'gt': [1.0] | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter learning_rate is invalid. It should be a dict and " | |||||
| "the value should be a float or a integer", | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -132,11 +135,9 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'ge': 1 | 'ge': 1 | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| "The parameter loss_function is invalid. It should be a dict and the value should be a string", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter loss_function is invalid. It should be a dict and " | |||||
| "the value should be a string", | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -145,12 +146,9 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'in': 2 | 'in': 2 | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter train_dataset_count is invalid. It should be a dict " | "The parameter train_dataset_count is invalid. It should be a dict " | ||||
| "and the value should be a integer between 0", | "and the value should be a integer between 0", | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -162,14 +160,14 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'eq': 'xxx' | 'eq': 'xxx' | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| "The parameter network is invalid. It should be a dict and the value should be a string", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter network is invalid. It should be a dict and " | |||||
| "the value should be a string", | |||||
| condition | condition | ||||
| ) | ) | ||||
| def test_validate_search_model_condition_mindinsight_exception_3(self): | |||||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||||
| condition = { | condition = { | ||||
| 'batch_size': { | 'batch_size': { | ||||
| 'lt': 2, | 'lt': 2, | ||||
| @@ -179,11 +177,8 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'eq': 222 | 'eq': 222 | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter batch_size is invalid. It should be a non-negative integer.", | "The parameter batch_size is invalid. It should be a non-negative integer.", | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -192,12 +187,9 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'lt': -2 | 'lt': -2 | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter test_dataset_count is invalid. It should be a dict " | "The parameter test_dataset_count is invalid. It should be a dict " | ||||
| "and the value should be a integer between 0", | "and the value should be a integer between 0", | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -206,11 +198,8 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| 'lt': False | 'lt': False | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter epoch is invalid. It should be a positive integer.", | "The parameter epoch is invalid. It should be a positive integer.", | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| @@ -219,65 +208,79 @@ class TestValidateSearchModelCondition(TestCase): | |||||
| "ge": "" | "ge": "" | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| "The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter learning_rate is invalid. It should be a dict and " | |||||
| "the value should be a float or a integer", | |||||
| condition | condition | ||||
| ) | ) | ||||
| def test_validate_search_model_condition_mindinsight_exception_4(self): | |||||
| """Test the mothod of validate_search_model_condition with MindinsightException.""" | |||||
| condition = { | condition = { | ||||
| "train_dataset_count": { | "train_dataset_count": { | ||||
| "ge": 8.0 | "ge": 8.0 | ||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| MindInsightException, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter train_dataset_count is invalid. It should be a dict " | "The parameter train_dataset_count is invalid. It should be a dict " | ||||
| "and the value should be a integer between 0", | "and the value should be a integer between 0", | ||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | condition | ||||
| ) | ) | ||||
| condition = { | condition = { | ||||
| 1: { | |||||
| "ge": 8.0 | |||||
| 'metric_attribute': { | |||||
| 'ge': 'xxx' | |||||
| } | } | ||||
| } | } | ||||
| self.assertRaisesRegex( | |||||
| LineageParamValueError, | |||||
| "The search attribute not supported.", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| self._assert_raise_of_mindinsight_exception( | |||||
| "The parameter metric_attribute is invalid. " | |||||
| "It should be a dict and the value should be a float or a integer", | |||||
| condition | condition | ||||
| ) | ) | ||||
| condition = { | |||||
| 'metric_': { | |||||
| "ge": 8.0 | |||||
| } | |||||
| } | |||||
| LineageParamValueError('The search attribute not supported.') | |||||
| self.assertRaisesRegex( | |||||
| LineageParamValueError, | |||||
| "The search attribute not supported.", | |||||
| validate_search_model_condition, | |||||
| SearchModelConditionParameter, | |||||
| condition | |||||
| ) | |||||
| def _assert_raise(self, exception, msg, condition): | |||||
| """ | |||||
| Assert raise by unittest. | |||||
| condition = { | |||||
| 'metric_attribute': { | |||||
| 'ge': 'xxx' | |||||
| } | |||||
| } | |||||
| Args: | |||||
| exception (Type): Exception class expected to be raised. | |||||
| msg (msg): Expected error message. | |||||
| condition (dict): The parameter of search condition. | |||||
| """ | |||||
| self.assertRaisesRegex( | self.assertRaisesRegex( | ||||
| MindInsightException, | |||||
| "The parameter metric_attribute is invalid. " | |||||
| "It should be a dict and the value should be a float or a integer", | |||||
| exception, | |||||
| msg, | |||||
| validate_search_model_condition, | validate_search_model_condition, | ||||
| SearchModelConditionParameter, | SearchModelConditionParameter, | ||||
| condition | condition | ||||
| ) | ) | ||||
| def _assert_raise_of_mindinsight_exception(self, msg, condition): | |||||
| """ | |||||
| Assert raise of MindinsightException by unittest. | |||||
| Args: | |||||
| msg (msg): Expected error message. | |||||
| condition (dict): The parameter of search condition. | |||||
| """ | |||||
| self._assert_raise(MindInsightException, msg, condition) | |||||
| def _assert_raise_of_lineage_param_value_error(self, msg, condition): | |||||
| """ | |||||
| Assert raise of LineageParamValueError by unittest. | |||||
| Args: | |||||
| msg (msg): Expected error message. | |||||
| condition (dict): The parameter of search condition. | |||||
| """ | |||||
| self._assert_raise(LineageParamValueError, msg, condition) | |||||
| def _assert_raise_of_lineage_param_type_error(self, msg, condition): | |||||
| """ | |||||
| Assert raise of LineageParamTypeError by unittest. | |||||
| Args: | |||||
| msg (msg): Expected error message. | |||||
| condition (dict): The parameter of search condition. | |||||
| """ | |||||
| self._assert_raise(LineageParamTypeError, msg, condition) | |||||
| @@ -15,6 +15,8 @@ | |||||
| """The event data in querier test.""" | """The event data in querier test.""" | ||||
| import json | import json | ||||
| from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE | |||||
| EVENT_TRAIN_DICT_0 = { | EVENT_TRAIN_DICT_0 = { | ||||
| 'wall_time': 1581499557.7017336, | 'wall_time': 1581499557.7017336, | ||||
| 'train_lineage': { | 'train_lineage': { | ||||
| @@ -373,49 +375,4 @@ EVENT_DATASET_DICT_0 = { | |||||
| } | } | ||||
| } | } | ||||
| DATASET_DICT_0 = { | |||||
| 'op_type': 'BatchDataset', | |||||
| 'op_module': 'minddata.dataengine.datasets', | |||||
| 'num_parallel_workers': None, | |||||
| 'drop_remainder': True, | |||||
| 'batch_size': 10, | |||||
| 'children': [ | |||||
| { | |||||
| 'op_type': 'MapDataset', | |||||
| 'op_module': 'minddata.dataengine.datasets', | |||||
| 'num_parallel_workers': None, | |||||
| 'input_columns': [ | |||||
| 'label' | |||||
| ], | |||||
| 'output_columns': [ | |||||
| None | |||||
| ], | |||||
| 'operations': [ | |||||
| { | |||||
| 'tensor_op_module': 'minddata.transforms.c_transforms', | |||||
| 'tensor_op_name': 'OneHot', | |||||
| 'num_classes': 10 | |||||
| } | |||||
| ], | |||||
| 'children': [ | |||||
| { | |||||
| 'op_type': 'MnistDataset', | |||||
| 'shard_id': None, | |||||
| 'num_shards': None, | |||||
| 'op_module': 'minddata.dataengine.datasets', | |||||
| 'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData', | |||||
| 'num_parallel_workers': None, | |||||
| 'shuffle': None, | |||||
| 'num_samples': 100, | |||||
| 'sampler': { | |||||
| 'sampler_module': 'minddata.dataengine.samplers', | |||||
| 'sampler_name': 'RandomSampler', | |||||
| 'replacement': True, | |||||
| 'num_samples': 100 | |||||
| }, | |||||
| 'children': [] | |||||
| } | |||||
| ] | |||||
| } | |||||
| ] | |||||
| } | |||||
| DATASET_DICT_0 = SERIALIZED_PIPELINE | |||||
| @@ -18,12 +18,12 @@ from unittest import TestCase, mock | |||||
| from google.protobuf.json_format import ParseDict | from google.protobuf.json_format import ParseDict | ||||
| import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2 | import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2 | ||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||||
| LineageQuerierParamException, LineageParamTypeError, \ | |||||
| LineageSummaryAnalyzeException, LineageSummaryParseException | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageParamTypeError, LineageQuerierParamException, | |||||
| LineageSummaryAnalyzeException, | |||||
| LineageSummaryParseException) | |||||
| from mindinsight.lineagemgr.querier.querier import Querier | from mindinsight.lineagemgr.querier.querier import Querier | ||||
| from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \ | |||||
| LineageInfo | |||||
| from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo | |||||
| from . import event_data | from . import event_data | ||||
| @@ -140,6 +140,98 @@ def get_lineage_infos(): | |||||
| return lineage_infos | return lineage_infos | ||||
| LINEAGE_INFO_0 = { | |||||
| 'summary_dir': '/path/to/summary0', | |||||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||||
| 'metric': event_data.METRIC_0, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| } | |||||
| LINEAGE_INFO_1 = { | |||||
| 'summary_dir': '/path/to/summary1', | |||||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||||
| 'metric': event_data.METRIC_1, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| } | |||||
| LINEAGE_FILTRATION_0 = create_filtration_result( | |||||
| '/path/to/summary0', | |||||
| event_data.EVENT_TRAIN_DICT_0, | |||||
| event_data.EVENT_EVAL_DICT_0, | |||||
| event_data.METRIC_0, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_1 = create_filtration_result( | |||||
| '/path/to/summary1', | |||||
| event_data.EVENT_TRAIN_DICT_1, | |||||
| event_data.EVENT_EVAL_DICT_1, | |||||
| event_data.METRIC_1, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_2 = create_filtration_result( | |||||
| '/path/to/summary2', | |||||
| event_data.EVENT_TRAIN_DICT_2, | |||||
| event_data.EVENT_EVAL_DICT_2, | |||||
| event_data.METRIC_2, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_3 = create_filtration_result( | |||||
| '/path/to/summary3', | |||||
| event_data.EVENT_TRAIN_DICT_3, | |||||
| event_data.EVENT_EVAL_DICT_3, | |||||
| event_data.METRIC_3, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_4 = create_filtration_result( | |||||
| '/path/to/summary4', | |||||
| event_data.EVENT_TRAIN_DICT_4, | |||||
| event_data.EVENT_EVAL_DICT_4, | |||||
| event_data.METRIC_4, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_5 = { | |||||
| "summary_dir": '/path/to/summary5', | |||||
| "loss_function": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||||
| "test_dataset_path": None, | |||||
| "test_dataset_count": None, | |||||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||||
| "learning_rate": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||||
| "metric": {}, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| } | |||||
| LINEAGE_FILTRATION_6 = { | |||||
| "summary_dir": '/path/to/summary6', | |||||
| "loss_function": None, | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": None, | |||||
| "test_dataset_path": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||||
| "test_dataset_count": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||||
| "network": None, | |||||
| "optimizer": None, | |||||
| "learning_rate": None, | |||||
| "epoch": None, | |||||
| "batch_size": None, | |||||
| "loss": None, | |||||
| "model_size": None, | |||||
| "metric": event_data.METRIC_5, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| } | |||||
| class TestQuerier(TestCase): | class TestQuerier(TestCase): | ||||
| """Test the class of `Querier`.""" | """Test the class of `Querier`.""" | ||||
| @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos') | @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos') | ||||
| @@ -169,31 +261,13 @@ class TestQuerier(TestCase): | |||||
| def test_get_summary_lineage_success_1(self): | def test_get_summary_lineage_success_1(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| expected_result = [ | |||||
| { | |||||
| 'summary_dir': '/path/to/summary0', | |||||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||||
| 'metric': event_data.METRIC_0, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| } | |||||
| ] | |||||
| expected_result = [LINEAGE_INFO_0] | |||||
| result = self.single_querier.get_summary_lineage() | result = self.single_querier.get_summary_lineage() | ||||
| self.assertListEqual(expected_result, result) | self.assertListEqual(expected_result, result) | ||||
| def test_get_summary_lineage_success_2(self): | def test_get_summary_lineage_success_2(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| expected_result = [ | |||||
| { | |||||
| 'summary_dir': '/path/to/summary0', | |||||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||||
| 'metric': event_data.METRIC_0, | |||||
| 'valid_dataset': | |||||
| event_data.EVENT_EVAL_DICT_0['evaluation_lineage'][ | |||||
| 'valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| } | |||||
| ] | |||||
| expected_result = [LINEAGE_INFO_0] | |||||
| result = self.single_querier.get_summary_lineage( | result = self.single_querier.get_summary_lineage( | ||||
| summary_dir='/path/to/summary0' | summary_dir='/path/to/summary0' | ||||
| ) | ) | ||||
| @@ -216,20 +290,8 @@ class TestQuerier(TestCase): | |||||
| def test_get_summary_lineage_success_4(self): | def test_get_summary_lineage_success_4(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| expected_result = [ | expected_result = [ | ||||
| { | |||||
| 'summary_dir': '/path/to/summary0', | |||||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||||
| 'metric': event_data.METRIC_0, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| }, | |||||
| { | |||||
| 'summary_dir': '/path/to/summary1', | |||||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||||
| 'metric': event_data.METRIC_1, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| }, | |||||
| LINEAGE_INFO_0, | |||||
| LINEAGE_INFO_1, | |||||
| { | { | ||||
| 'summary_dir': '/path/to/summary2', | 'summary_dir': '/path/to/summary2', | ||||
| **event_data.EVENT_TRAIN_DICT_2['train_lineage'], | **event_data.EVENT_TRAIN_DICT_2['train_lineage'], | ||||
| @@ -274,15 +336,7 @@ class TestQuerier(TestCase): | |||||
| def test_get_summary_lineage_success_5(self): | def test_get_summary_lineage_success_5(self): | ||||
| """Test the success of get_summary_lineage.""" | """Test the success of get_summary_lineage.""" | ||||
| expected_result = [ | |||||
| { | |||||
| 'summary_dir': '/path/to/summary1', | |||||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||||
| 'metric': event_data.METRIC_1, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| } | |||||
| ] | |||||
| expected_result = [LINEAGE_INFO_1] | |||||
| result = self.multi_querier.get_summary_lineage( | result = self.multi_querier.get_summary_lineage( | ||||
| summary_dir='/path/to/summary1' | summary_dir='/path/to/summary1' | ||||
| ) | ) | ||||
| @@ -341,20 +395,8 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | 'object': [ | ||||
| create_filtration_result( | |||||
| '/path/to/summary1', | |||||
| event_data.EVENT_TRAIN_DICT_1, | |||||
| event_data.EVENT_EVAL_DICT_1, | |||||
| event_data.METRIC_1, | |||||
| event_data.DATASET_DICT_0, | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary2', | |||||
| event_data.EVENT_TRAIN_DICT_2, | |||||
| event_data.EVENT_EVAL_DICT_2, | |||||
| event_data.METRIC_2, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_1, | |||||
| LINEAGE_FILTRATION_2 | |||||
| ], | ], | ||||
| 'count': 2, | 'count': 2, | ||||
| } | } | ||||
| @@ -377,20 +419,8 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | 'object': [ | ||||
| create_filtration_result( | |||||
| '/path/to/summary2', | |||||
| event_data.EVENT_TRAIN_DICT_2, | |||||
| event_data.EVENT_EVAL_DICT_2, | |||||
| event_data.METRIC_2, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary3', | |||||
| event_data.EVENT_TRAIN_DICT_3, | |||||
| event_data.EVENT_EVAL_DICT_3, | |||||
| event_data.METRIC_3, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_2, | |||||
| LINEAGE_FILTRATION_3 | |||||
| ], | ], | ||||
| 'count': 2, | 'count': 2, | ||||
| } | } | ||||
| @@ -405,20 +435,8 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | 'object': [ | ||||
| create_filtration_result( | |||||
| '/path/to/summary2', | |||||
| event_data.EVENT_TRAIN_DICT_2, | |||||
| event_data.EVENT_EVAL_DICT_2, | |||||
| event_data.METRIC_2, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary3', | |||||
| event_data.EVENT_TRAIN_DICT_3, | |||||
| event_data.EVENT_EVAL_DICT_3, | |||||
| event_data.METRIC_3, | |||||
| event_data.DATASET_DICT_0 | |||||
| ) | |||||
| LINEAGE_FILTRATION_2, | |||||
| LINEAGE_FILTRATION_3 | |||||
| ], | ], | ||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| @@ -429,82 +447,13 @@ class TestQuerier(TestCase): | |||||
| """Test the success of filter_summary_lineage.""" | """Test the success of filter_summary_lineage.""" | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | 'object': [ | ||||
| create_filtration_result( | |||||
| '/path/to/summary0', | |||||
| event_data.EVENT_TRAIN_DICT_0, | |||||
| event_data.EVENT_EVAL_DICT_0, | |||||
| event_data.METRIC_0, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary1', | |||||
| event_data.EVENT_TRAIN_DICT_1, | |||||
| event_data.EVENT_EVAL_DICT_1, | |||||
| event_data.METRIC_1, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary2', | |||||
| event_data.EVENT_TRAIN_DICT_2, | |||||
| event_data.EVENT_EVAL_DICT_2, | |||||
| event_data.METRIC_2, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary3', | |||||
| event_data.EVENT_TRAIN_DICT_3, | |||||
| event_data.EVENT_EVAL_DICT_3, | |||||
| event_data.METRIC_3, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary4', | |||||
| event_data.EVENT_TRAIN_DICT_4, | |||||
| event_data.EVENT_EVAL_DICT_4, | |||||
| event_data.METRIC_4, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| { | |||||
| "summary_dir": '/path/to/summary5', | |||||
| "loss_function": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||||
| "test_dataset_path": None, | |||||
| "test_dataset_count": None, | |||||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||||
| "learning_rate": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||||
| "metric": {}, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| }, | |||||
| { | |||||
| "summary_dir": '/path/to/summary6', | |||||
| "loss_function": None, | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": None, | |||||
| "test_dataset_path": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||||
| "test_dataset_count": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||||
| "network": None, | |||||
| "optimizer": None, | |||||
| "learning_rate": None, | |||||
| "epoch": None, | |||||
| "batch_size": None, | |||||
| "loss": None, | |||||
| "model_size": None, | |||||
| "metric": event_data.METRIC_5, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| } | |||||
| LINEAGE_FILTRATION_0, | |||||
| LINEAGE_FILTRATION_1, | |||||
| LINEAGE_FILTRATION_2, | |||||
| LINEAGE_FILTRATION_3, | |||||
| LINEAGE_FILTRATION_4, | |||||
| LINEAGE_FILTRATION_5, | |||||
| LINEAGE_FILTRATION_6 | |||||
| ], | ], | ||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| @@ -519,15 +468,7 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | |||||
| create_filtration_result( | |||||
| '/path/to/summary4', | |||||
| event_data.EVENT_TRAIN_DICT_4, | |||||
| event_data.EVENT_EVAL_DICT_4, | |||||
| event_data.METRIC_4, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| ], | |||||
| 'object': [LINEAGE_FILTRATION_4], | |||||
| 'count': 1, | 'count': 1, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| @@ -541,82 +482,13 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | 'object': [ | ||||
| create_filtration_result( | |||||
| '/path/to/summary0', | |||||
| event_data.EVENT_TRAIN_DICT_0, | |||||
| event_data.EVENT_EVAL_DICT_0, | |||||
| event_data.METRIC_0, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| { | |||||
| "summary_dir": '/path/to/summary5', | |||||
| "loss_function": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||||
| "test_dataset_path": None, | |||||
| "test_dataset_count": None, | |||||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||||
| "learning_rate": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||||
| "metric": {}, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| }, | |||||
| create_filtration_result( | |||||
| '/path/to/summary1', | |||||
| event_data.EVENT_TRAIN_DICT_1, | |||||
| event_data.EVENT_EVAL_DICT_1, | |||||
| event_data.METRIC_1, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary2', | |||||
| event_data.EVENT_TRAIN_DICT_2, | |||||
| event_data.EVENT_EVAL_DICT_2, | |||||
| event_data.METRIC_2, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary3', | |||||
| event_data.EVENT_TRAIN_DICT_3, | |||||
| event_data.EVENT_EVAL_DICT_3, | |||||
| event_data.METRIC_3, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary4', | |||||
| event_data.EVENT_TRAIN_DICT_4, | |||||
| event_data.EVENT_EVAL_DICT_4, | |||||
| event_data.METRIC_4, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| { | |||||
| "summary_dir": '/path/to/summary6', | |||||
| "loss_function": None, | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": None, | |||||
| "test_dataset_path": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||||
| "test_dataset_count": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||||
| "network": None, | |||||
| "optimizer": None, | |||||
| "learning_rate": None, | |||||
| "epoch": None, | |||||
| "batch_size": None, | |||||
| "loss": None, | |||||
| "model_size": None, | |||||
| "metric": event_data.METRIC_5, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| } | |||||
| LINEAGE_FILTRATION_0, | |||||
| LINEAGE_FILTRATION_5, | |||||
| LINEAGE_FILTRATION_1, | |||||
| LINEAGE_FILTRATION_2, | |||||
| LINEAGE_FILTRATION_3, | |||||
| LINEAGE_FILTRATION_4, | |||||
| LINEAGE_FILTRATION_6 | |||||
| ], | ], | ||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| @@ -631,82 +503,13 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | 'object': [ | ||||
| { | |||||
| "summary_dir": '/path/to/summary6', | |||||
| "loss_function": None, | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": None, | |||||
| "test_dataset_path": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], | |||||
| "test_dataset_count": | |||||
| event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'], | |||||
| "network": None, | |||||
| "optimizer": None, | |||||
| "learning_rate": None, | |||||
| "epoch": None, | |||||
| "batch_size": None, | |||||
| "loss": None, | |||||
| "model_size": None, | |||||
| "metric": event_data.METRIC_5, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| }, | |||||
| create_filtration_result( | |||||
| '/path/to/summary4', | |||||
| event_data.EVENT_TRAIN_DICT_4, | |||||
| event_data.EVENT_EVAL_DICT_4, | |||||
| event_data.METRIC_4, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary3', | |||||
| event_data.EVENT_TRAIN_DICT_3, | |||||
| event_data.EVENT_EVAL_DICT_3, | |||||
| event_data.METRIC_3, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary2', | |||||
| event_data.EVENT_TRAIN_DICT_2, | |||||
| event_data.EVENT_EVAL_DICT_2, | |||||
| event_data.METRIC_2, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary1', | |||||
| event_data.EVENT_TRAIN_DICT_1, | |||||
| event_data.EVENT_EVAL_DICT_1, | |||||
| event_data.METRIC_1, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| create_filtration_result( | |||||
| '/path/to/summary0', | |||||
| event_data.EVENT_TRAIN_DICT_0, | |||||
| event_data.EVENT_EVAL_DICT_0, | |||||
| event_data.METRIC_0, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| { | |||||
| "summary_dir": '/path/to/summary5', | |||||
| "loss_function": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'], | |||||
| "train_dataset_path": None, | |||||
| "train_dataset_count": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'], | |||||
| "test_dataset_path": None, | |||||
| "test_dataset_count": None, | |||||
| "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'], | |||||
| "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'], | |||||
| "learning_rate": | |||||
| event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'], | |||||
| "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'], | |||||
| "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'], | |||||
| "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'], | |||||
| "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'], | |||||
| "metric": {}, | |||||
| "dataset_graph": event_data.DATASET_DICT_0, | |||||
| "dataset_mark": '2' | |||||
| } | |||||
| LINEAGE_FILTRATION_6, | |||||
| LINEAGE_FILTRATION_4, | |||||
| LINEAGE_FILTRATION_3, | |||||
| LINEAGE_FILTRATION_2, | |||||
| LINEAGE_FILTRATION_1, | |||||
| LINEAGE_FILTRATION_0, | |||||
| LINEAGE_FILTRATION_5 | |||||
| ], | ], | ||||
| 'count': 7, | 'count': 7, | ||||
| } | } | ||||
| @@ -722,15 +525,7 @@ class TestQuerier(TestCase): | |||||
| } | } | ||||
| } | } | ||||
| expected_result = { | expected_result = { | ||||
| 'object': [ | |||||
| create_filtration_result( | |||||
| '/path/to/summary4', | |||||
| event_data.EVENT_TRAIN_DICT_4, | |||||
| event_data.EVENT_EVAL_DICT_4, | |||||
| event_data.METRIC_4, | |||||
| event_data.DATASET_DICT_0 | |||||
| ), | |||||
| ], | |||||
| 'object': [LINEAGE_FILTRATION_4], | |||||
| 'count': 1, | 'count': 1, | ||||
| } | } | ||||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | result = self.multi_querier.filter_summary_lineage(condition=condition) | ||||
| @@ -809,20 +604,8 @@ class TestQuerier(TestCase): | |||||
| querier = Querier(summary_path) | querier = Querier(summary_path) | ||||
| querier._parse_failed_paths.append('/path/to/summary1/log1') | querier._parse_failed_paths.append('/path/to/summary1/log1') | ||||
| expected_result = [ | expected_result = [ | ||||
| { | |||||
| 'summary_dir': '/path/to/summary0', | |||||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||||
| 'metric': event_data.METRIC_0, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| }, | |||||
| { | |||||
| 'summary_dir': '/path/to/summary1', | |||||
| **event_data.EVENT_TRAIN_DICT_1['train_lineage'], | |||||
| 'metric': event_data.METRIC_1, | |||||
| 'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| } | |||||
| LINEAGE_INFO_0, | |||||
| LINEAGE_INFO_1 | |||||
| ] | ] | ||||
| result = querier.get_summary_lineage() | result = querier.get_summary_lineage() | ||||
| self.assertListEqual(expected_result, result) | self.assertListEqual(expected_result, result) | ||||
| @@ -842,17 +625,7 @@ class TestQuerier(TestCase): | |||||
| querier._parse_failed_paths.append('/path/to/summary1/log1') | querier._parse_failed_paths.append('/path/to/summary1/log1') | ||||
| args[0].return_value = create_lineage_info(None, None, None) | args[0].return_value = create_lineage_info(None, None, None) | ||||
| expected_result = [ | |||||
| { | |||||
| 'summary_dir': '/path/to/summary0', | |||||
| **event_data.EVENT_TRAIN_DICT_0['train_lineage'], | |||||
| 'metric': event_data.METRIC_0, | |||||
| 'valid_dataset': | |||||
| event_data.EVENT_EVAL_DICT_0['evaluation_lineage'][ | |||||
| 'valid_dataset'], | |||||
| 'dataset_graph': event_data.DATASET_DICT_0 | |||||
| } | |||||
| ] | |||||
| expected_result = [LINEAGE_INFO_0] | |||||
| result = querier.get_summary_lineage() | result = querier.get_summary_lineage() | ||||
| self.assertListEqual(expected_result, result) | self.assertListEqual(expected_result, result) | ||||
| self.assertListEqual( | self.assertListEqual( | ||||
| @@ -15,11 +15,12 @@ | |||||
| """Test the query_model module.""" | """Test the query_model module.""" | ||||
| from unittest import TestCase | from unittest import TestCase | ||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import \ | |||||
| LineageEventNotExistException, LineageEventFieldNotExistException | |||||
| from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException, | |||||
| LineageEventNotExistException) | |||||
| from mindinsight.lineagemgr.querier.query_model import LineageObj | from mindinsight.lineagemgr.querier.query_model import LineageObj | ||||
| from . import event_data | from . import event_data | ||||
| from .test_querier import create_lineage_info, create_filtration_result | |||||
| from .test_querier import create_filtration_result, create_lineage_info | |||||
| class TestLineageObj(TestCase): | class TestLineageObj(TestCase): | ||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -19,10 +19,10 @@ import time | |||||
| from google.protobuf import json_format | from google.protobuf import json_format | ||||
| from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator | |||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | ||||
| from .log_generator import LogGenerator | |||||
| class GraphLogGenerator(LogGenerator): | class GraphLogGenerator(LogGenerator): | ||||
| """ | """ | ||||
| @@ -74,7 +74,7 @@ class GraphLogGenerator(LogGenerator): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| graph_log_generator = GraphLogGenerator() | graph_log_generator = GraphLogGenerator() | ||||
| test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time())) | test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time())) | ||||
| graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json") | |||||
| graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators--", "graph_base.json") | |||||
| with open(graph_base_path, 'r') as load_f: | with open(graph_base_path, 'r') as load_f: | ||||
| graph = json.load(load_f) | graph = json.load(load_f) | ||||
| graph_log_generator.generate_log(test_file_name, graph) | graph_log_generator.generate_log(test_file_name, graph) | ||||
| @@ -18,10 +18,11 @@ import time | |||||
| import numpy as np | import numpy as np | ||||
| from PIL import Image | from PIL import Image | ||||
| from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator | |||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | ||||
| from .log_generator import LogGenerator | |||||
| class ImagesLogGenerator(LogGenerator): | class ImagesLogGenerator(LogGenerator): | ||||
| """ | """ | ||||
| @@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator): | |||||
| images_metadata.append(image_metadata) | images_metadata.append(image_metadata) | ||||
| images_values.update({step: image_tensor}) | images_values.update({step: image_tensor}) | ||||
| values = dict( | |||||
| wall_time=wall_time, | |||||
| step=step, | |||||
| image=image_tensor, | |||||
| tag=tag_name | |||||
| ) | |||||
| values = dict(wall_time=wall_time, step=step, image=image_tensor, tag=tag_name) | |||||
| self._write_log_one_step(file_path, values) | self._write_log_one_step(file_path, values) | ||||
| @@ -17,7 +17,7 @@ | |||||
| import struct | import struct | ||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| from tests.st.func.datavisual.utils import crc32 | |||||
| from ...utils import crc32 | |||||
| class LogGenerator: | class LogGenerator: | ||||
| @@ -16,10 +16,11 @@ | |||||
| import time | import time | ||||
| import numpy as np | import numpy as np | ||||
| from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator | |||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | ||||
| from .log_generator import LogGenerator | |||||
| class ScalarsLogGenerator(LogGenerator): | class ScalarsLogGenerator(LogGenerator): | ||||
| """ | """ | ||||
| @@ -19,13 +19,12 @@ import json | |||||
| import os | import os | ||||
| import time | import time | ||||
| from tests.st.func.datavisual.constants import SUMMARY_PREFIX | |||||
| from tests.st.func.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator | |||||
| from tests.st.func.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator | |||||
| from tests.st.func.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator | |||||
| from mindinsight.datavisual.common.enums import PluginNameEnum | from mindinsight.datavisual.common.enums import PluginNameEnum | ||||
| from .log_generators.graph_log_generator import GraphLogGenerator | |||||
| from .log_generators.images_log_generator import ImagesLogGenerator | |||||
| from .log_generators.scalars_log_generator import ScalarsLogGenerator | |||||
| log_generators = { | log_generators = { | ||||
| PluginNameEnum.GRAPH.value: GraphLogGenerator(), | PluginNameEnum.GRAPH.value: GraphLogGenerator(), | ||||
| PluginNameEnum.IMAGE.value: ImagesLogGenerator(), | PluginNameEnum.IMAGE.value: ImagesLogGenerator(), | ||||
| @@ -35,10 +34,12 @@ log_generators = { | |||||
| class LogOperations: | class LogOperations: | ||||
| """Log Operations.""" | """Log Operations.""" | ||||
| def __init__(self): | def __init__(self): | ||||
| self._step_num = 3 | self._step_num = 3 | ||||
| self._tag_num = 2 | self._tag_num = 2 | ||||
| self._time_count = 0 | self._time_count = 0 | ||||
| self._graph_base_path = os.path.join(os.path.dirname(__file__), "log_generators", "graph_base.json") | |||||
| def _get_steps(self): | def _get_steps(self): | ||||
| """Get steps.""" | """Get steps.""" | ||||
| @@ -61,9 +62,7 @@ class LogOperations: | |||||
| metadata_dict["plugins"].update({plugin_name: list()}) | metadata_dict["plugins"].update({plugin_name: list()}) | ||||
| log_generator = log_generators.get(plugin_name) | log_generator = log_generators.get(plugin_name) | ||||
| if plugin_name == PluginNameEnum.GRAPH.value: | if plugin_name == PluginNameEnum.GRAPH.value: | ||||
| graph_base_path = os.path.join(os.path.dirname(__file__), | |||||
| os.pardir, "utils", "log_generators", "graph_base.json") | |||||
| with open(graph_base_path, 'r') as load_f: | |||||
| with open(self._graph_base_path, 'r') as load_f: | |||||
| graph_dict = json.load(load_f) | graph_dict = json.load(load_f) | ||||
| values = log_generator.generate_log(file_path, graph_dict) | values = log_generator.generate_log(file_path, graph_dict) | ||||
| metadata_dict["actual_values"].update({plugin_name: values}) | metadata_dict["actual_values"].update({plugin_name: values}) | ||||
| @@ -82,13 +81,13 @@ class LogOperations: | |||||
| self._time_count += 1 | self._time_count += 1 | ||||
| return metadata_dict | return metadata_dict | ||||
| def create_summary_logs(self, summary_base_dir, summary_dir_num, start_index=0): | |||||
| def create_summary_logs(self, summary_base_dir, summary_dir_num, dir_prefix, start_index=0): | |||||
| """Create summary logs in summary_base_dir.""" | """Create summary logs in summary_base_dir.""" | ||||
| summary_metadata = dict() | summary_metadata = dict() | ||||
| steps_list = self._get_steps() | steps_list = self._get_steps() | ||||
| tag_name_list = self._get_tags() | tag_name_list = self._get_tags() | ||||
| for i in range(start_index, summary_dir_num + start_index): | for i in range(start_index, summary_dir_num + start_index): | ||||
| log_dir = os.path.join(summary_base_dir, f'{SUMMARY_PREFIX}{i}') | |||||
| log_dir = os.path.join(summary_base_dir, f'{dir_prefix}{i}') | |||||
| os.makedirs(log_dir) | os.makedirs(log_dir) | ||||
| train_id = log_dir.replace(summary_base_dir, ".") | train_id = log_dir.replace(summary_base_dir, ".") | ||||
| @@ -120,3 +119,47 @@ class LogOperations: | |||||
| metadata_dict = self.create_summary(log_dir, steps_list, tag_name_list) | metadata_dict = self.create_summary(log_dir, steps_list, tag_name_list) | ||||
| return {train_id: metadata_dict} | return {train_id: metadata_dict} | ||||
| def generate_log(self, plugin_name, log_dir, log_settings=None, valid=True): | |||||
| """ | |||||
| Generate log for ut. | |||||
| Args: | |||||
| plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'. | |||||
| log_dir (str): Log path to write log. | |||||
| log_settings (dict): Info about the log, e.g.: | |||||
| { | |||||
| current_time (int): Timestamp in summary file name, not necessary. | |||||
| graph_base_path (str): Path of graph_bas.json, necessary for `graph`. | |||||
| steps (list[int]): Steps for `image` and `scalar`, default is [1]. | |||||
| tag (str): Tag name, default is 'default_tag'. | |||||
| } | |||||
| valid (bool): If true, summary name will be valid. | |||||
| Returns: | |||||
| str, Summary log path. | |||||
| """ | |||||
| if log_settings is None: | |||||
| log_settings = dict() | |||||
| current_time = log_settings.get('time', int(time.time())) | |||||
| current_time = int(current_time) | |||||
| log_generator = log_generators.get(plugin_name) | |||||
| if valid: | |||||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time))) | |||||
| else: | |||||
| temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time))) | |||||
| if plugin_name == PluginNameEnum.GRAPH.value: | |||||
| with open(self._graph_base_path, 'r') as load_f: | |||||
| graph_dict = json.load(load_f) | |||||
| graph_dict = log_generator.generate_log(temp_path, graph_dict) | |||||
| return temp_path, graph_dict | |||||
| steps_list = log_settings.get('steps', [1]) | |||||
| tag_name = log_settings.get('tag', 'default_tag') | |||||
| metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name) | |||||
| return temp_path, metadata, values | |||||