| @@ -18,19 +18,21 @@ This file is used to define the basic graph. | |||
| import copy | |||
| import time | |||
| from enum import Enum | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.common import exceptions | |||
| from .node import NodeTypeEnum | |||
| from .node import Node | |||
| class EdgeTypeEnum: | |||
| class EdgeTypeEnum(Enum): | |||
| """Node edge type enum.""" | |||
| control = 'control' | |||
| data = 'data' | |||
| CONTROL = 'control' | |||
| DATA = 'data' | |||
| class DataTypeEnum: | |||
| class DataTypeEnum(Enum): | |||
| """Data type enum.""" | |||
| DT_TENSOR = 13 | |||
| @@ -292,70 +294,65 @@ class Graph: | |||
| output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value | |||
| node.update_output({dst_name: output_attr}) | |||
| def _calc_polymeric_input_output(self): | |||
| def _update_polymeric_input_output(self): | |||
| """Calc polymeric input and output after build polymeric node.""" | |||
| for name, node in self._normal_nodes.items(): | |||
| polymeric_input = {} | |||
| for src_name in node.input: | |||
| src_node = self._polymeric_nodes.get(src_name) | |||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||
| src_name = src_name if not src_node else src_node.polymeric_scope_name | |||
| output_name = self._calc_dummy_node_name(name, src_name) | |||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| continue | |||
| if not src_node: | |||
| continue | |||
| if not node.name_scope and src_node.name_scope: | |||
| # if current node is in first layer, and the src node is not in | |||
| # the first layer, the src node will not be the polymeric input of current node. | |||
| continue | |||
| if node.name_scope == src_node.name_scope \ | |||
| or node.name_scope.startswith(src_node.name_scope): | |||
| polymeric_input.update( | |||
| {src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| for node in self._normal_nodes.values(): | |||
| polymeric_input = self._calc_polymeric_attr(node, 'input') | |||
| node.update_polymeric_input(polymeric_input) | |||
| polymeric_output = {} | |||
| for dst_name in node.output: | |||
| dst_node = self._polymeric_nodes.get(dst_name) | |||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||
| dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name | |||
| output_name = self._calc_dummy_node_name(name, dst_name) | |||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| continue | |||
| if not dst_node: | |||
| continue | |||
| if not node.name_scope and dst_node.name_scope: | |||
| continue | |||
| if node.name_scope == dst_node.name_scope \ | |||
| or node.name_scope.startswith(dst_node.name_scope): | |||
| polymeric_output.update( | |||
| {dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| polymeric_output = self._calc_polymeric_attr(node, 'output') | |||
| node.update_polymeric_output(polymeric_output) | |||
| for name, node in self._polymeric_nodes.items(): | |||
| polymeric_input = {} | |||
| for src_name in node.input: | |||
| output_name = self._calc_dummy_node_name(name, src_name) | |||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| node.update_polymeric_input(polymeric_input) | |||
| polymeric_output = {} | |||
| for dst_name in node.output: | |||
| polymeric_output = {} | |||
| output_name = self._calc_dummy_node_name(name, dst_name) | |||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) | |||
| polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| node.update_polymeric_output(polymeric_output) | |||
| def _calc_polymeric_attr(self, node, attr): | |||
| """ | |||
| Calc polymeric input or polymeric output after build polymeric node. | |||
| Args: | |||
| node (Node): Computes the polymeric input for a given node. | |||
| attr (str): The polymeric attr, optional value is `input` or `output`. | |||
| Returns: | |||
| dict, return polymeric input or polymeric output of the given node. | |||
| """ | |||
| polymeric_attr = {} | |||
| for node_name in getattr(node, attr): | |||
| polymeric_node = self._polymeric_nodes.get(node_name) | |||
| if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: | |||
| node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name | |||
| dummy_node_name = self._calc_dummy_node_name(node.name, node_name) | |||
| polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| continue | |||
| if not polymeric_node: | |||
| continue | |||
| if not node.name_scope and polymeric_node.name_scope: | |||
| # If current node is in top-level layer, and the polymeric_node node is not in | |||
| # the top-level layer, the polymeric node will not be the polymeric input | |||
| # or polymeric output of current node. | |||
| continue | |||
| if node.name_scope == polymeric_node.name_scope \ | |||
| or node.name_scope.startswith(polymeric_node.name_scope + '/'): | |||
| polymeric_attr.update( | |||
| {polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}}) | |||
| return polymeric_attr | |||
| def _calc_dummy_node_name(self, current_node_name, other_node_name): | |||
| """ | |||
| Calc dummy node name. | |||
| @@ -39,7 +39,7 @@ class MSGraph(Graph): | |||
| self._build_leaf_nodes(graph_proto) | |||
| self._build_polymeric_nodes() | |||
| self._build_name_scope_nodes() | |||
| self._calc_polymeric_input_output() | |||
| self._update_polymeric_input_output() | |||
| logger.info("Build graph end, normal node count: %s, polymeric node " | |||
| "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes)) | |||
| @@ -90,9 +90,9 @@ class MSGraph(Graph): | |||
| node_name = leaf_node_id_map_name[node_def.name] | |||
| node = self._leaf_nodes[node_name] | |||
| for input_def in node_def.input: | |||
| edge_type = EdgeTypeEnum.data | |||
| edge_type = EdgeTypeEnum.DATA.value | |||
| if input_def.type == "CONTROL_EDGE": | |||
| edge_type = EdgeTypeEnum.control | |||
| edge_type = EdgeTypeEnum.CONTROL.value | |||
| if const_nodes_map.get(input_def.name): | |||
| const_node = copy.deepcopy(const_nodes_map[input_def.name]) | |||
| @@ -218,7 +218,7 @@ class MSGraph(Graph): | |||
| node = Node(name=const.key, node_id=const_node_id) | |||
| node.node_type = NodeTypeEnum.CONST.value | |||
| node.update_attr({const.key: str(const.value)}) | |||
| if const.value.dtype == DataTypeEnum.DT_TENSOR: | |||
| if const.value.dtype == DataTypeEnum.DT_TENSOR.value: | |||
| shape = [] | |||
| for dim in const.value.tensor_val.dims: | |||
| shape.append(dim) | |||
| @@ -172,7 +172,7 @@ class Node: | |||
| Args: | |||
| polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name: | |||
| {'edge_type': EdgeTypeEnum.data}}). | |||
| {'edge_type': EdgeTypeEnum.DATA.value}}). | |||
| """ | |||
| self._polymeric_output.update(polymeric_output) | |||
| @@ -19,11 +19,11 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import os | |||
| import json | |||
| import pytest | |||
| from .. import globals as gbl | |||
| from .....utils.tools import get_url | |||
| from .....utils.tools import get_url, compare_result_with_file | |||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' | |||
| @@ -33,12 +33,6 @@ class TestQueryNodes: | |||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -65,4 +59,5 @@ class TestQueryNodes: | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| self.compare_result_with_file(response.get_json(), result_file) | |||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(response.get_json(), file_path) | |||
| @@ -19,12 +19,11 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import os | |||
| import json | |||
| import pytest | |||
| from .. import globals as gbl | |||
| from .....utils.tools import get_url | |||
| from .....utils.tools import get_url, compare_result_with_file | |||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' | |||
| @@ -34,12 +33,6 @@ class TestQuerySingleNode: | |||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -59,4 +52,5 @@ class TestQuerySingleNode: | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| self.compare_result_with_file(response.get_json(), result_file) | |||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(response.get_json(), file_path) | |||
| @@ -19,11 +19,11 @@ Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import os | |||
| import json | |||
| import pytest | |||
| from .. import globals as gbl | |||
| from .....utils.tools import get_url | |||
| from .....utils.tools import get_url, compare_result_with_file | |||
| BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' | |||
| @@ -33,12 +33,6 @@ class TestSearchNodes: | |||
| graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -59,4 +53,5 @@ class TestSearchNodes: | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| self.compare_result_with_file(response.get_json(), result_file) | |||
| file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(response.get_json(), file_path) | |||
| @@ -29,7 +29,7 @@ import pytest | |||
| from ..mock import MockLogger | |||
| from ....utils.log_operations import LogOperations | |||
| from ....utils.tools import check_loading_done, delete_files_or_dirs | |||
| from ....utils.tools import check_loading_done, delete_files_or_dirs, compare_result_with_file | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| @@ -103,12 +103,6 @@ class TestGraphProcessor: | |||
| # wait for loading done | |||
| check_loading_done(self._mock_data_manager, time_limit=5) | |||
| def compare_result_with_file(self, result, filename): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: | |||
| expected_results = json.load(fp) | |||
| assert result == expected_results | |||
| def test_get_nodes_with_not_exist_train_id(self, load_graph_record): | |||
| """Test getting nodes with not exist train id.""" | |||
| test_train_id = "not_exist_train_id" | |||
| @@ -152,7 +146,9 @@ class TestGraphProcessor: | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| results = graph_processor.get_nodes(name, node_type) | |||
| self.compare_result_with_file(results, result_file) | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| @pytest.mark.parametrize("search_content, result_file", [ | |||
| (None, 'test_search_node_names_with_search_content_expected_results1.json'), | |||
| @@ -175,7 +171,8 @@ class TestGraphProcessor: | |||
| expected_results = {'names': []} | |||
| assert results == expected_results | |||
| else: | |||
| self.compare_result_with_file(results, result_file) | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| @pytest.mark.parametrize("offset", [-100, -1]) | |||
| def test_search_node_names_with_negative_offset(self, load_graph_record, offset): | |||
| @@ -203,7 +200,8 @@ class TestGraphProcessor: | |||
| results = graph_processor.search_node_names(test_search_content, | |||
| test_offset, | |||
| test_limit) | |||
| self.compare_result_with_file(results, result_file) | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| def test_search_node_names_with_wrong_limit(self, load_graph_record): | |||
| """Test search node names with wrong limit.""" | |||
| @@ -227,7 +225,8 @@ class TestGraphProcessor: | |||
| graph_processor = GraphProcessor(self._train_id, | |||
| self._mock_data_manager) | |||
| results = graph_processor.search_single_node(name) | |||
| self.compare_result_with_file(results, result_file) | |||
| expected_file_path = os.path.join(self.graph_results_dir, result_file) | |||
| compare_result_with_file(results, expected_file_path) | |||
| def test_search_single_node_with_not_exist_name(self, load_graph_record): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -19,9 +19,13 @@ import io | |||
| import os | |||
| import shutil | |||
| import time | |||
| import json | |||
| from urllib.parse import urlencode | |||
| import numpy as np | |||
| from PIL import Image | |||
| from mindinsight.datavisual.common.enums import DataManagerStatus | |||
| @@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string): | |||
| image_tensor = np.array(img) | |||
| return image_tensor | |||
| def compare_result_with_file(result, expected_file_path): | |||
| """Compare result with file which contain the expected results.""" | |||
| with open(expected_file_path, 'r') as file: | |||
| expected_results = json.load(file) | |||
| assert result == expected_results | |||