diff --git a/mindinsight/debugger/conditionmgr/condition.py b/mindinsight/debugger/conditionmgr/condition.py index ded17fc4..4be8196c 100644 --- a/mindinsight/debugger/conditionmgr/condition.py +++ b/mindinsight/debugger/conditionmgr/condition.py @@ -81,9 +81,9 @@ class ParamTypeEnum(Enum): class ActivationFuncEnum(Enum): """Activation functions.""" - TANH = 'Tanh' - SIGMOID = 'Sigmoid' - RELU = 'ReLU' + TANH = 'tanh' + SIGMOID = 'sigmoid' + RELU = 'relu' class ConditionContext: diff --git a/mindinsight/debugger/stream_cache/node_type_identifier.py b/mindinsight/debugger/stream_cache/node_type_identifier.py index a2637789..e24f9cc2 100644 --- a/mindinsight/debugger/stream_cache/node_type_identifier.py +++ b/mindinsight/debugger/stream_cache/node_type_identifier.py @@ -19,20 +19,21 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError _ACTIVATIONS = [ - 'ELU', - 'FastGelu', - 'GELU', - 'HSigmoid', - 'HSwish', - 'LeakyReLU', - 'LogSigmoid', - 'LogSoftmax', - 'PReLU', - 'ReLU', - 'ReLU6', - 'Sigmoid', - 'Softmax', - 'Tanh' + 'elu', + 'fastgelu', + 'gelu', + 'hsigmoid', + 'hswish', + 'leakyrelu', + 'logsigmoid', + 'logsoftmax', + 'prelu', + 'relu', + 'relu6', + 'reluv2', + 'sigmoid', + 'softmax', + 'tanh' ] @@ -122,7 +123,7 @@ def is_activation_node(node, condition=None): if not is_gradient_node(node): node_type = node.type for activation_name in activation_funcs: - if node_type == activation_name: + if node_type.lower() == activation_name: return True return False diff --git a/tests/st/func/debugger/expect_results/restful_results/search_activation.json b/tests/st/func/debugger/expect_results/restful_results/search_activation.json index d70f0312..2e9a04a9 100644 --- a/tests/st/func/debugger/expect_results/restful_results/search_activation.json +++ b/tests/st/func/debugger/expect_results/restful_results/search_activation.json @@ -1 +1,48 @@ -{"nodes": [{"name": "Default", "type": "name_scope", "nodes": [{"name": "Default/network-WithLossCell", "type": "name_scope", "nodes": [{"name": "Default/network-WithLossCell/_backbone-LeNet5", "type": "name_scope", "nodes": [{"name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU", "type": "name_scope", "nodes": [{"name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12", "type": "ReLU", "nodes": []}, {"name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15", "type": "ReLU", "nodes": []}]}]}]}]}]} \ No newline at end of file +{ + "nodes": [ + { + "name": "Default", + "type": "name_scope", + "nodes": [ + { + "name": "Default/network-WithLossCell", + "type": "name_scope", + "nodes": [ + { + "name": "Default/network-WithLossCell/_backbone-LeNet5", + "type": "name_scope", + "nodes": [ + { + "name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU", + "type": "name_scope", + "nodes": [ + { + "name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op87", + "type": "ReLUV2", + "nodes": [] + }, + { + "name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89", + "type": "ReLUV2", + "nodes": [] + }, + { + "name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12", + "type": "ReLU", + "nodes": [] + }, + { + "name": "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15", + "type": "ReLU", + "nodes": [] + } + ] + } + ] + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/search_activation_multi_graph.json b/tests/st/func/debugger/expect_results/restful_results/search_activation_multi_graph.json index e6131e3e..15c25fd1 100644 --- a/tests/st/func/debugger/expect_results/restful_results/search_activation_multi_graph.json +++ b/tests/st/func/debugger/expect_results/restful_results/search_activation_multi_graph.json @@ -1 +1,104 @@ -{"nodes": [{"name": "graph_0", "type": "name_scope", "nodes": [{"name": "graph_0/Default", "type": "name_scope", "nodes": [{"name": "graph_0/Default/network-WithLossCell", "type": "name_scope", "nodes": [{"name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5", "type": "name_scope", "nodes": [{"name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU", "type": "name_scope", "nodes": [{"name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12", "type": "ReLU", "nodes": []}, {"name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15", "type": "ReLU", "nodes": []}]}]}]}]}]}, {"name": "graph_1", "type": "name_scope", "nodes": [{"name": "graph_1/Default", "type": "name_scope", "nodes": [{"name": "graph_1/Default/network-WithLossCell", "type": "name_scope", "nodes": [{"name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5", "type": "name_scope", "nodes": [{"name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU", "type": "name_scope", "nodes": [{"name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12", "type": "ReLU", "nodes": []}, {"name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15", "type": "ReLU", "nodes": []}]}]}]}]}]}]} \ No newline at end of file +{ + "nodes": [ + { + "name": "graph_0", + "type": "name_scope", + "nodes": [ + { + "name": "graph_0/Default", + "type": "name_scope", + "nodes": [ + { + "name": "graph_0/Default/network-WithLossCell", + "type": "name_scope", + "nodes": [ + { + "name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5", + "type": "name_scope", + "nodes": [ + { + "name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU", + "type": "name_scope", + "nodes": [ + { + "name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op87", + "type": "ReLUV2", + "nodes": [] + }, + { + "name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89", + "type": "ReLUV2", + "nodes": [] + }, + { + "name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12", + "type": "ReLU", + "nodes": [] + }, + { + "name": "graph_0/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15", + "type": "ReLU", + "nodes": [] + } + ] + } + ] + } + ] + } + ] + } + ] + }, + { + "name": "graph_1", + "type": "name_scope", + "nodes": [ + { + "name": "graph_1/Default", + "type": "name_scope", + "nodes": [ + { + "name": "graph_1/Default/network-WithLossCell", + "type": "name_scope", + "nodes": [ + { + "name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5", + "type": "name_scope", + "nodes": [ + { + "name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU", + "type": "name_scope", + "nodes": [ + { + "name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op87", + "type": "ReLUV2", + "nodes": [] + }, + { + "name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89", + "type": "ReLUV2", + "nodes": [] + }, + { + "name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12", + "type": "ReLU", + "nodes": [] + }, + { + "name": "graph_1/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15", + "type": "ReLU", + "nodes": [] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/ut/debugger/stream_cache/test_node_type_identifier.py b/tests/ut/debugger/stream_cache/test_node_type_identifier.py index 971d4092..56350ed9 100644 --- a/tests/ut/debugger/stream_cache/test_node_type_identifier.py +++ b/tests/ut/debugger/stream_cache/test_node_type_identifier.py @@ -60,8 +60,8 @@ class TestNodeTypeIdentifier: ('Default/mock/relu_ReLU-op11', "ReLU", None, True), ('Gradients/mock/relu_ReLU-op11', "ReLU", None, False), ('Default/mock/relu_ReLU-op11', "Parameter", None, False), - ('Default/mock/relu_ReLU-op11', "ReLU", {'activation_func': 'Softmax'}, False), - ('Default/mock/relu_ReLU-op11', "Softmax", {'activation_func': ['ReLU', 'Softmax']}, True) + ('Default/mock/relu_ReLU-op11', "ReLU", {'activation_func': 'softmax'}, False), + ('Default/mock/relu_ReLU-op11', "Softmax", {'activation_func': ['relu', 'softmax']}, True) ]) def test_activate_node(self, name, node_type, condition, result): """Test activate node.""" diff --git a/tests/ut/debugger/stream_handler/test_graph_handler.py b/tests/ut/debugger/stream_handler/test_graph_handler.py index b591f2c1..b840a1cd 100644 --- a/tests/ut/debugger/stream_handler/test_graph_handler.py +++ b/tests/ut/debugger/stream_handler/test_graph_handler.py @@ -73,7 +73,7 @@ class TestGraphHandler: @pytest.mark.parametrize("node_type, condition, result_file", [ ("weight", None, "search_nodes_by_type_0.json"), - ("activation", {'activation_func': ['ReLU', 'Softmax']}, "search_nodes_by_type_1.json") + ("activation", {'activation_func': ['relu', 'softmax']}, "search_nodes_by_type_1.json") ]) def test_search_nodes_by_type(self, node_type, condition, result_file): """Test search nodes by type."""