diff --git a/tests/ut/debugger/configurations.py b/tests/ut/debugger/configurations.py index c9a3fb12..fd5ac870 100644 --- a/tests/ut/debugger/configurations.py +++ b/tests/ut/debugger/configurations.py @@ -18,17 +18,20 @@ import os from google.protobuf import json_format +from mindinsight.datavisual.data_transform.graph import NodeTypeEnum +from mindinsight.debugger.common.utils import NodeBasicInfo from mindinsight.debugger.proto import ms_graph_pb2 from mindinsight.debugger.stream_handler.graph_handler import GraphHandler +from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler -graph_proto_file = os.path.join( +GRAPH_PROTO_FILE = os.path.join( os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb' ) def get_graph_proto(): """Get graph proto.""" - with open(graph_proto_file, 'rb') as f: + with open(GRAPH_PROTO_FILE, 'rb') as f: content = f.read() graph = ms_graph_pb2.GraphProto() @@ -38,7 +41,7 @@ def get_graph_proto(): def init_graph_handler(): - """Init graph proto.""" + """Init GraphHandler.""" graph = get_graph_proto() graph_handler = GraphHandler() graph_handler.put(graph) @@ -46,6 +49,47 @@ def init_graph_handler(): return graph_handler +def init_watchpoint_hit_handler(value): + """Init WatchpointHitHandler.""" + wph_handler = WatchpointHitHandler() + wph_handler.put(value) + + return wph_handler + + +def get_node_basic_infos(node_names): + """Get node info according to node names.""" + if not node_names: + return [] + graph_stream = init_graph_handler() + node_infos = [] + for node_name in node_names: + node_type = graph_stream.get_node_type(node_name) + if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: + sub_nodes = graph_stream.get_nodes_by_scope(node_name) + sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) + for node in sub_nodes] + node_infos.extend(sub_infos) + full_name = graph_stream.get_full_name(node_name) + node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type)) + return node_infos + + +def get_watch_nodes_by_search(watch_nodes): + """Get watched leaf nodes by search name.""" + watched_leaf_nodes = [] + graph_stream = init_graph_handler() + for search_name in watch_nodes: + search_nodes = graph_stream.get_searched_node_list() + search_node_names = [ + NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) + for node in search_nodes + if node.name.startswith(search_name)] + watched_leaf_nodes.extend(search_node_names) + + return watched_leaf_nodes + + def mock_tensor_proto(): """Mock tensor proto.""" tensor_dict = { diff --git a/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json b/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json new file mode 100644 index 00000000..07b9df21 --- /dev/null +++ b/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json @@ -0,0 +1,33 @@ +[ + { + "watchCondition": { + "condition": "inf" + }, + "id": 1 + }, + { + "watchCondition": { + "condition": "inf" + }, + "id": 2, + "watchNodes": [ + { + "nodeName": "Default", + "nodeType": "scope" + } + ] + }, + { + "watchCondition": { + "condition": "max_gt", + "value": 1.0 + }, + "id": 3, + "watchNodes": [ + { + "nodeName": "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92", + "nodeType": "leaf" + } + ] + } +] \ No newline at end of file diff --git a/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_1.json b/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_1.json new file mode 100644 index 00000000..e62a2adc --- /dev/null +++ b/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_1.json @@ -0,0 +1 @@ +[{"id": 1, "watch_condition": {"condition": "INF"}}, {"id": 2, "watch_condition": {"condition": "INF"}}, {"id": 3, "watch_condition": {"condition": "MAX_GT", "param": 1}}] diff --git a/tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_0.json b/tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_0.json new file mode 100644 index 00000000..fe18c1ca --- /dev/null +++ b/tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_0.json @@ -0,0 +1 @@ +{"watch_point_hits": [{"node_name": "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92", "watch_points": [{"id": 1, "watch_condition": {"condition": "MAX_GT", "param": 1}}]}]} diff --git a/tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_1.json b/tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_1.json new file mode 100644 index 00000000..ec747fa4 --- /dev/null +++ b/tests/ut/debugger/expected_results/watchpoint/watchpoint_hit_handler_get_1.json @@ -0,0 +1 @@ +null \ No newline at end of file diff --git a/tests/ut/debugger/stream_handler/test_watchpoint_handler.py b/tests/ut/debugger/stream_handler/test_watchpoint_handler.py index 27b093af..67e1f6ec 100644 --- a/tests/ut/debugger/stream_handler/test_watchpoint_handler.py +++ b/tests/ut/debugger/stream_handler/test_watchpoint_handler.py @@ -17,6 +17,7 @@ import json import os from unittest import mock, TestCase +from google.protobuf import json_format import pytest from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ @@ -27,7 +28,9 @@ from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHan WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \ - mock_tensor_history + mock_tensor_history, get_node_basic_infos, get_watch_nodes_by_search, \ + init_watchpoint_hit_handler +from tests.utils.tools import compare_result_with_file class TestWatchpointHandler: @@ -36,6 +39,8 @@ class TestWatchpointHandler: def setup_class(cls): """Init WatchpointHandler for watchpoint unittest.""" cls.handler = WatchpointHandler() + cls.results_dir = os.path.join(os.path.dirname(__file__), + '../expected_results/watchpoint') cls.graph_results_dir = os.path.join(os.path.dirname(__file__), '../expected_results/graph') cls.graph_stream = init_graph_handler() @@ -47,26 +52,64 @@ class TestWatchpointHandler: ({'condition': 'MAX_GT', 'param': 1}, ["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], None, 3) ]) - @mock.patch.object(Watchpoint, 'add_nodes') - def test_create_watchpoint(self, mock_add_nodes, watch_condition, watch_nodes, + def test_create_watchpoint(self, watch_condition, watch_nodes, watch_point_id, expect_new_id): """Test create_watchpoint.""" - mock_add_nodes.return_value = None + watch_nodes = get_node_basic_infos(watch_nodes) watch_point_id = self.handler.create_watchpoint(watch_condition, watch_nodes, watch_point_id) assert watch_point_id == expect_new_id - @pytest.mark.parametrize("filter_condition", [True, False]) - @mock.patch.object(Watchpoint, 'get_set_cmd') - @mock.patch.object(Watchpoint, 'get_watch_condition_info') - def test_get(self, mock_get_wp, mock_get_cmd, filter_condition): - """Test get.""" - mock_get_wp.return_value = None - mock_get_cmd.return_value = None + @pytest.mark.parametrize( + "watch_point_id, watch_nodes, watched, expect_updated_id", [ + (3, ["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], 1, 3), + (3, [], 1, 3) + ]) + def test_update_watchpoint_add(self, watch_point_id, watch_nodes, watched, expect_updated_id): + """Test update_watchpoint on addition.""" + watch_nodes = get_node_basic_infos(watch_nodes) + with TestCase().assertLogs(logger=log, level='DEBUG') as log_content: + self.handler.update_watchpoint(watch_point_id, watch_nodes, watched) + TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.", + log_content.output) + + @pytest.mark.parametrize( + "watch_point_id, watch_nodes, watched, expect_updated_id", [ + (2, ["Default"], 0, 2), + (2, [], 0, 2), + ]) + def test_update_watchpoint_delete(self, watch_point_id, watch_nodes, watched, expect_updated_id): + """Test update_watchpoint on deletion.""" + watch_nodes = get_watch_nodes_by_search(watch_nodes) with TestCase().assertLogs(logger=log, level='DEBUG') as log_content: - self.handler.get(filter_condition) - TestCase().assertIn(f"DEBUG:debugger.debugger:get the watch points with filter_condition:{filter_condition}", + self.handler.update_watchpoint(watch_point_id, watch_nodes, watched) + TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.", log_content.output) + @pytest.mark.parametrize("filter_condition, result_file", [ + (True, 'watchpoint_handler_get_0.json') + ]) + def test_get_filter_true(self, filter_condition, result_file): + """Test get with filter_condition is True.""" + file_path = os.path.join(self.results_dir, result_file) + with open(file_path, 'r') as f: + contents = json.load(f) + + reply = self.handler.get(filter_condition) + protos = reply.get('watch_points') + for proto in protos: + msg_dict = json_format.MessageToDict(proto) + assert msg_dict in contents + + @pytest.mark.parametrize("filter_condition, result_file", [ + (False, 'watchpoint_handler_get_1.json') + ]) + def test_get_filter_false(self, filter_condition, result_file): + """Test get with filer_condition is False.""" + file_path = os.path.join(self.results_dir, result_file) + reply = self.handler.get(filter_condition) + watch_points = reply.get('watch_points') + compare_result_with_file(watch_points, file_path) + def test_get_watchpoint_by_id_except(self): """Test get_watchpoint_by_id.""" watchpoint_id = 4 @@ -78,32 +121,12 @@ class TestWatchpointHandler: @pytest.mark.parametrize("graph_file, watch_point_id", [ ('graph_handler_get_3_single_node.json', 4) ]) - @mock.patch.object(WatchpointHandler, '_set_watch_status_recursively') - def test_set_watch_nodes(self, mock_set_recur, graph_file, watch_point_id): + def test_set_watch_nodes(self, graph_file, watch_point_id): """Test set_watch_nodes.""" path = os.path.join(self.graph_results_dir, graph_file) with open(path, 'r') as f: graph = json.load(f) - instance = mock_set_recur.return_value self.handler.set_watch_nodes(graph, self.graph_stream, watch_point_id) - assert instance.iscalled() - - @pytest.mark.parametrize( - "watch_point_id, watch_nodes, watched, expect_updated_id", [ - (2, ["Default"], 0, 2), - (3, ["Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92"], 1, 3) - ]) - @mock.patch.object(Watchpoint, 'remove_nodes') - @mock.patch.object(Watchpoint, 'add_nodes') - def test_update_watchpoint(self, mock_add_nodes, mock_remove_nodes, watch_point_id, watch_nodes, - watched, expect_updated_id): - """Test update_watchpoint.""" - mock_add_nodes.return_value = None - mock_remove_nodes.return_value = None - with TestCase().assertLogs(logger=log, level='DEBUG') as log_content: - self.handler.update_watchpoint(watch_point_id, watch_nodes, watched) - TestCase().assertIn(f"DEBUG:debugger.debugger:Update watchpoint {expect_updated_id} in cache.", - log_content.output) @pytest.mark.parametrize( "watch_point_id, expect_deleted_ids", [ @@ -119,42 +142,52 @@ class TestWatchpointHandler: class TestWatchpointHitHandler: """Test WatchpointHitHandler.""" + watchpoint = Watchpoint( + watch_condition={'condition': 'MAX_GT', 'param': 1}, + watchpoint_id=1 + ) + value = { + 'tensor_proto': mock_tensor_proto(), + 'watchpoint': watchpoint, + 'node_name': 'Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92', + 'finished': True, + 'slot': 0 + } + @classmethod def setup_class(cls): """Setup.""" - cls.handler = WatchpointHitHandler() - cls.tensor_proto = mock_tensor_proto() + cls.handler = init_watchpoint_hit_handler(cls.value) cls.tensor_hist = mock_tensor_history() + cls.results_dir = os.path.join(os.path.dirname(__file__), + '../expected_results/watchpoint') @mock.patch('mindinsight.debugger.stream_cache.watchpoint.WatchpointHit') def test_put(self, mock_hit): """Test put.""" - value = { - 'tensor_proto': self.tensor_proto, - 'watchpoint': {'id': 1, 'watch_condition': {'condition': 'INF'}}, - 'node_name': 'Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92' - } mock_hit.return_value = mock.MagicMock( - tensor_proto=value.get('tensor_proto'), - watchpoint=value.get('watchpoint'), - node_name=value.get('node_name') + tensor_proto=self.value.get('tensor_proto'), + watchpoint=self.value.get('watchpoint'), + node_name=self.value.get('node_name') ) - self.handler.put(value) + WatchpointHitHandler().put(self.value) - @pytest.mark.parametrize("filter_condition", [ - None, "Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190" + @pytest.mark.parametrize("filter_condition, result_file", [ + (None, "watchpoint_hit_handler_get_0.json"), + ("Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190", + "watchpoint_hit_handler_get_1.json") ]) - @mock.patch.object(WatchpointHitHandler, 'get_watchpoint_hits') - def test_get(self, mock_get, filter_condition): + def test_get(self, filter_condition, result_file): """Test get.""" - mock_get.return_value = {'watch_point_hits': []} - self.handler.get(filter_condition) + reply = self.handler.get(filter_condition) + file_path = os.path.join(self.results_dir, result_file) + compare_result_with_file(reply, file_path) - @mock.patch.object(WatchpointHitHandler, '_is_tensor_hit') - def test_update_tensor_history(self, mock_hit): + def test_update_tensor_history(self): """Test update_tensor_history.""" - mock_hit.side_effect = [True, False] self.handler.update_tensor_history(self.tensor_hist) + for tensor_info in self.tensor_hist.get('tensor_history'): + assert tensor_info['is_hit'] is False def test_validate_watch_condition_type_error():