From 8ea4d18cec9aa8a378e516f7014ac7e2a9b36391 Mon Sep 17 00:00:00 2001 From: ougongchang Date: Mon, 30 Mar 2020 14:15:15 +0800 Subject: [PATCH] Extract the common function methods and reduced cyclomatic complexity of functions --- .../datavisual/data_transform/graph/graph.py | 99 +++++++++---------- .../data_transform/graph/msgraph.py | 8 +- .../datavisual/data_transform/graph/node.py | 2 +- .../graph/test_query_nodes_restful_api.py | 13 +-- .../test_query_single_nodes_restful_api.py | 12 +-- .../graph/test_search_nodes_restful_api.py | 13 +-- .../processors/test_graph_processor.py | 21 ++-- tests/utils/tools.py | 13 ++- 8 files changed, 86 insertions(+), 95 deletions(-) diff --git a/mindinsight/datavisual/data_transform/graph/graph.py b/mindinsight/datavisual/data_transform/graph/graph.py index 4bdaf139..01cdcf93 100644 --- a/mindinsight/datavisual/data_transform/graph/graph.py +++ b/mindinsight/datavisual/data_transform/graph/graph.py @@ -18,19 +18,21 @@ This file is used to define the basic graph. import copy import time +from enum import Enum + from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common import exceptions from .node import NodeTypeEnum from .node import Node -class EdgeTypeEnum: +class EdgeTypeEnum(Enum): """Node edge type enum.""" - control = 'control' - data = 'data' + CONTROL = 'control' + DATA = 'data' -class DataTypeEnum: +class DataTypeEnum(Enum): """Data type enum.""" DT_TENSOR = 13 @@ -292,70 +294,65 @@ class Graph: output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value node.update_output({dst_name: output_attr}) - def _calc_polymeric_input_output(self): + def _update_polymeric_input_output(self): """Calc polymeric input and output after build polymeric node.""" - for name, node in self._normal_nodes.items(): - polymeric_input = {} - for src_name in node.input: - src_node = self._polymeric_nodes.get(src_name) - if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: - src_name = src_name if not src_node else src_node.polymeric_scope_name - output_name = self._calc_dummy_node_name(name, src_name) - polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) - continue - - if not src_node: - continue - - if not node.name_scope and src_node.name_scope: - # if current node is in first layer, and the src node is not in - # the first layer, the src node will not be the polymeric input of current node. - continue - - if node.name_scope == src_node.name_scope \ - or node.name_scope.startswith(src_node.name_scope): - polymeric_input.update( - {src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) - + for node in self._normal_nodes.values(): + polymeric_input = self._calc_polymeric_attr(node, 'input') node.update_polymeric_input(polymeric_input) - polymeric_output = {} - for dst_name in node.output: - dst_node = self._polymeric_nodes.get(dst_name) - - if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: - dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name - output_name = self._calc_dummy_node_name(name, dst_name) - polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) - continue - - if not dst_node: - continue - - if not node.name_scope and dst_node.name_scope: - continue - - if node.name_scope == dst_node.name_scope \ - or node.name_scope.startswith(dst_node.name_scope): - polymeric_output.update( - {dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}}) - + polymeric_output = self._calc_polymeric_attr(node, 'output') node.update_polymeric_output(polymeric_output) for name, node in self._polymeric_nodes.items(): polymeric_input = {} for src_name in node.input: output_name = self._calc_dummy_node_name(name, src_name) - polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) + polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) node.update_polymeric_input(polymeric_input) polymeric_output = {} for dst_name in node.output: polymeric_output = {} output_name = self._calc_dummy_node_name(name, dst_name) - polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) + polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}}) node.update_polymeric_output(polymeric_output) + def _calc_polymeric_attr(self, node, attr): + """ + Calc polymeric input or polymeric output after build polymeric node. + + Args: + node (Node): Computes the polymeric input for a given node. + attr (str): The polymeric attr, optional value is `input` or `output`. + + Returns: + dict, return polymeric input or polymeric output of the given node. + """ + polymeric_attr = {} + for node_name in getattr(node, attr): + polymeric_node = self._polymeric_nodes.get(node_name) + if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: + node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name + dummy_node_name = self._calc_dummy_node_name(node.name, node_name) + polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}}) + continue + + if not polymeric_node: + continue + + if not node.name_scope and polymeric_node.name_scope: + # If current node is in top-level layer, and the polymeric_node node is not in + # the top-level layer, the polymeric node will not be the polymeric input + # or polymeric output of current node. + continue + + if node.name_scope == polymeric_node.name_scope \ + or node.name_scope.startswith(polymeric_node.name_scope + '/'): + polymeric_attr.update( + {polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}}) + + return polymeric_attr + def _calc_dummy_node_name(self, current_node_name, other_node_name): """ Calc dummy node name. diff --git a/mindinsight/datavisual/data_transform/graph/msgraph.py b/mindinsight/datavisual/data_transform/graph/msgraph.py index 607002f5..db9891c1 100644 --- a/mindinsight/datavisual/data_transform/graph/msgraph.py +++ b/mindinsight/datavisual/data_transform/graph/msgraph.py @@ -39,7 +39,7 @@ class MSGraph(Graph): self._build_leaf_nodes(graph_proto) self._build_polymeric_nodes() self._build_name_scope_nodes() - self._calc_polymeric_input_output() + self._update_polymeric_input_output() logger.info("Build graph end, normal node count: %s, polymeric node " "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes)) @@ -90,9 +90,9 @@ class MSGraph(Graph): node_name = leaf_node_id_map_name[node_def.name] node = self._leaf_nodes[node_name] for input_def in node_def.input: - edge_type = EdgeTypeEnum.data + edge_type = EdgeTypeEnum.DATA.value if input_def.type == "CONTROL_EDGE": - edge_type = EdgeTypeEnum.control + edge_type = EdgeTypeEnum.CONTROL.value if const_nodes_map.get(input_def.name): const_node = copy.deepcopy(const_nodes_map[input_def.name]) @@ -218,7 +218,7 @@ class MSGraph(Graph): node = Node(name=const.key, node_id=const_node_id) node.node_type = NodeTypeEnum.CONST.value node.update_attr({const.key: str(const.value)}) - if const.value.dtype == DataTypeEnum.DT_TENSOR: + if const.value.dtype == DataTypeEnum.DT_TENSOR.value: shape = [] for dim in const.value.tensor_val.dims: shape.append(dim) diff --git a/mindinsight/datavisual/data_transform/graph/node.py b/mindinsight/datavisual/data_transform/graph/node.py index 2923a2c8..280db941 100644 --- a/mindinsight/datavisual/data_transform/graph/node.py +++ b/mindinsight/datavisual/data_transform/graph/node.py @@ -172,7 +172,7 @@ class Node: Args: polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name: - {'edge_type': EdgeTypeEnum.data}}). + {'edge_type': EdgeTypeEnum.DATA.value}}). """ self._polymeric_output.update(polymeric_output) diff --git a/tests/st/func/datavisual/graph/test_query_nodes_restful_api.py b/tests/st/func/datavisual/graph/test_query_nodes_restful_api.py index da0d7a12..9e86baae 100644 --- a/tests/st/func/datavisual/graph/test_query_nodes_restful_api.py +++ b/tests/st/func/datavisual/graph/test_query_nodes_restful_api.py @@ -19,11 +19,11 @@ Usage: pytest tests/st/func/datavisual """ import os -import json + import pytest from .. import globals as gbl -from .....utils.tools import get_url +from .....utils.tools import get_url, compare_result_with_file BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' @@ -33,12 +33,6 @@ class TestQueryNodes: graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') - def compare_result_with_file(self, result, filename): - """Compare result with file which contain the expected results.""" - with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: - expected_results = json.load(fp) - assert result == expected_results - @pytest.mark.level0 @pytest.mark.env_single @pytest.mark.platform_x86_cpu @@ -65,4 +59,5 @@ class TestQueryNodes: url = get_url(BASE_URL, params) response = client.get(url) assert response.status_code == 200 - self.compare_result_with_file(response.get_json(), result_file) + file_path = os.path.join(self.graph_results_dir, result_file) + compare_result_with_file(response.get_json(), file_path) diff --git a/tests/st/func/datavisual/graph/test_query_single_nodes_restful_api.py b/tests/st/func/datavisual/graph/test_query_single_nodes_restful_api.py index df94df1c..e93420fa 100644 --- a/tests/st/func/datavisual/graph/test_query_single_nodes_restful_api.py +++ b/tests/st/func/datavisual/graph/test_query_single_nodes_restful_api.py @@ -19,12 +19,11 @@ Usage: pytest tests/st/func/datavisual """ import os -import json import pytest from .. import globals as gbl -from .....utils.tools import get_url +from .....utils.tools import get_url, compare_result_with_file BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' @@ -34,12 +33,6 @@ class TestQuerySingleNode: graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') - def compare_result_with_file(self, result, filename): - """Compare result with file which contain the expected results.""" - with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: - expected_results = json.load(fp) - assert result == expected_results - @pytest.mark.level0 @pytest.mark.env_single @pytest.mark.platform_x86_cpu @@ -59,4 +52,5 @@ class TestQuerySingleNode: url = get_url(BASE_URL, params) response = client.get(url) assert response.status_code == 200 - self.compare_result_with_file(response.get_json(), result_file) + file_path = os.path.join(self.graph_results_dir, result_file) + compare_result_with_file(response.get_json(), file_path) diff --git a/tests/st/func/datavisual/graph/test_search_nodes_restful_api.py b/tests/st/func/datavisual/graph/test_search_nodes_restful_api.py index 0bedd274..3b0365b1 100644 --- a/tests/st/func/datavisual/graph/test_search_nodes_restful_api.py +++ b/tests/st/func/datavisual/graph/test_search_nodes_restful_api.py @@ -19,11 +19,11 @@ Usage: pytest tests/st/func/datavisual """ import os -import json import pytest + from .. import globals as gbl -from .....utils.tools import get_url +from .....utils.tools import get_url, compare_result_with_file BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' @@ -33,12 +33,6 @@ class TestSearchNodes: graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') - def compare_result_with_file(self, result, filename): - """Compare result with file which contain the expected results.""" - with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: - expected_results = json.load(fp) - assert result == expected_results - @pytest.mark.level0 @pytest.mark.env_single @pytest.mark.platform_x86_cpu @@ -59,4 +53,5 @@ class TestSearchNodes: url = get_url(BASE_URL, params) response = client.get(url) assert response.status_code == 200 - self.compare_result_with_file(response.get_json(), result_file) + file_path = os.path.join(self.graph_results_dir, result_file) + compare_result_with_file(response.get_json(), file_path) diff --git a/tests/ut/datavisual/processors/test_graph_processor.py b/tests/ut/datavisual/processors/test_graph_processor.py index 29b4313f..e2118ea6 100644 --- a/tests/ut/datavisual/processors/test_graph_processor.py +++ b/tests/ut/datavisual/processors/test_graph_processor.py @@ -29,7 +29,7 @@ import pytest from ..mock import MockLogger from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs +from ....utils.tools import check_loading_done, delete_files_or_dirs, compare_result_with_file from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import PluginNameEnum @@ -103,12 +103,6 @@ class TestGraphProcessor: # wait for loading done check_loading_done(self._mock_data_manager, time_limit=5) - def compare_result_with_file(self, result, filename): - """Compare result with file which contain the expected results.""" - with open(os.path.join(self.graph_results_dir, filename), 'r') as fp: - expected_results = json.load(fp) - assert result == expected_results - def test_get_nodes_with_not_exist_train_id(self, load_graph_record): """Test getting nodes with not exist train id.""" test_train_id = "not_exist_train_id" @@ -152,7 +146,9 @@ class TestGraphProcessor: graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) results = graph_processor.get_nodes(name, node_type) - self.compare_result_with_file(results, result_file) + + expected_file_path = os.path.join(self.graph_results_dir, result_file) + compare_result_with_file(results, expected_file_path) @pytest.mark.parametrize("search_content, result_file", [ (None, 'test_search_node_names_with_search_content_expected_results1.json'), @@ -175,7 +171,8 @@ class TestGraphProcessor: expected_results = {'names': []} assert results == expected_results else: - self.compare_result_with_file(results, result_file) + expected_file_path = os.path.join(self.graph_results_dir, result_file) + compare_result_with_file(results, expected_file_path) @pytest.mark.parametrize("offset", [-100, -1]) def test_search_node_names_with_negative_offset(self, load_graph_record, offset): @@ -203,7 +200,8 @@ class TestGraphProcessor: results = graph_processor.search_node_names(test_search_content, test_offset, test_limit) - self.compare_result_with_file(results, result_file) + expected_file_path = os.path.join(self.graph_results_dir, result_file) + compare_result_with_file(results, expected_file_path) def test_search_node_names_with_wrong_limit(self, load_graph_record): """Test search node names with wrong limit.""" @@ -227,7 +225,8 @@ class TestGraphProcessor: graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) results = graph_processor.search_single_node(name) - self.compare_result_with_file(results, result_file) + expected_file_path = os.path.join(self.graph_results_dir, result_file) + compare_result_with_file(results, expected_file_path) def test_search_single_node_with_not_exist_name(self, load_graph_record): diff --git a/tests/utils/tools.py b/tests/utils/tools.py index 885b0796..8dbd45d7 100644 --- a/tests/utils/tools.py +++ b/tests/utils/tools.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,9 +19,13 @@ import io import os import shutil import time +import json + from urllib.parse import urlencode + import numpy as np from PIL import Image + from mindinsight.datavisual.common.enums import DataManagerStatus @@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string): image_tensor = np.array(img) return image_tensor + + +def compare_result_with_file(result, expected_file_path): + """Compare result with file which contain the expected results.""" + with open(expected_file_path, 'r') as file: + expected_results = json.load(file) + assert result == expected_results