diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py index a78aca4c..19cbae58 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -189,8 +189,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): self._received_hit = [] for watchpoint_hit in watchpoint_hits: watchpoint_hit_stream.put(watchpoint_hit) - watchpoint_hits_info = watchpoint_hit_stream.get() - watchpoint_hits_info.update({'receive_watchpoint_hits': True}) + watchpoint_hits_info = {'receive_watchpoint_hits': True} self._cache_store.put_data(watchpoint_hits_info) log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.") diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index 1badb368..5b794045 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -242,7 +242,6 @@ class DebuggerServer: 'all': self._retrieve_all, 'node': self._retrieve_node, 'watchpoint': self._retrieve_watchpoint, - 'watchpoint_hit': self._retrieve_watchpoint_hit } # validate param if mode not in mode_mapping.keys(): @@ -470,40 +469,6 @@ class DebuggerServer: return reply - def _retrieve_watchpoint_hit(self, filter_condition): - """ - Retrieve watchpoint hit. - - Args: - filter_condition (dict): Filter condition. - - - name (str): The name of single node. - - single_node (bool): If False, return the sub-layer of single node. If True, return - the node list from root node to single node. - - Returns: - dict, watch point list or relative graph. - """ - node_name = filter_condition.get('name') - # get all watchpoint hit list - if node_name is None: - reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() - reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() - return reply - graph_name = self.cache_store.get_stream_handler(Streams.GRAPH).validate_graph_name( - filter_condition.get('graph_name')) - # get tensor history - reply = self._get_tensor_history(node_name, graph_name) - log.debug("Get tensor history for watchpoint hit node.") - # get single graph - if filter_condition.get('single_node'): - filter_condition['graph_name'] = graph_name - graph = self._get_nodes_info(filter_condition) - reply.update(graph) - log.debug("Get tensor history for watchpoint hit node.") - - return reply - def search_watchpoint_hits(self, group_condition): """ Retrieve watchpoint hit. diff --git a/tests/st/func/debugger/test_restful_api.py b/tests/st/func/debugger/test_restful_api.py index f0b69b56..00e7b4a3 100644 --- a/tests/st/func/debugger/test_restful_api.py +++ b/tests/st/func/debugger/test_restful_api.py @@ -71,8 +71,7 @@ class TestAscendDebugger: 'retrieve_aggregation_scope_node.json'), ({'mode': 'node', 'params': { 'name': 'Default/TransData-op99', - 'single_node': True}}, 'retrieve_single_node.json'), - ({'mode': 'watchpoint_hit'}, 'retrieve_empty_watchpoint_hit_list') + 'single_node': True}}, 'retrieve_single_node.json') ]) def test_retrieve_when_train_begin(self, app_client, body_data, expect_file): """Test retrieve when train_begin.""" @@ -178,34 +177,6 @@ class TestAscendDebugger: send_and_compare_result(app_client, url, body_data, expect_file, method='get') send_terminate_cmd(app_client) - @pytest.mark.level0 - @pytest.mark.env_single - @pytest.mark.platform_x86_cpu - @pytest.mark.platform_arm_ascend_training - @pytest.mark.platform_x86_gpu_training - @pytest.mark.platform_x86_ascend_training - def test_watchpoint_hit(self, app_client): - """Test retrieve watchpoint hit.""" - with self._debugger_client.get_thread_instance(): - create_watchpoint_and_wait(app_client) - # check watchpoint hit list - url = 'retrieve' - body_data = {'mode': 'watchpoint_hit'} - expect_file = 'retrieve_watchpoint_hit.json' - send_and_compare_result(app_client, url, body_data, expect_file) - # check single watchpoint hit - body_data = { - 'mode': 'watchpoint_hit', - 'params': { - 'name': 'Default/TransData-op99', - 'single_node': True, - 'watch_point_id': 1 - } - } - expect_file = 'retrieve_single_watchpoint_hit.json' - send_and_compare_result(app_client, url, body_data, expect_file) - send_terminate_cmd(app_client) - @pytest.mark.level0 @pytest.mark.env_single @pytest.mark.platform_x86_cpu diff --git a/tests/ut/debugger/test_debugger_server.py b/tests/ut/debugger/test_debugger_server.py index 506eeeb8..6a08babe 100644 --- a/tests/ut/debugger/test_debugger_server.py +++ b/tests/ut/debugger/test_debugger_server.py @@ -180,14 +180,6 @@ class TestDebuggerServer: res = self._server._retrieve_watchpoint({'watch_point_id': 1}) assert res == mock_watchpoint - @mock.patch.object(DebuggerServer, '_get_tensor_history') - @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}}) - def test_retrieve_watchpoint_hit(self, *args): - """Test retrieve single watchpoint.""" - args[1].return_value = {'tensor_history': {}} - res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True}) - assert res == {'tensor_history': {}, 'graph': {}} - def test_create_watchpoint_with_wrong_state(self): """Test create watchpoint with wrong state.""" with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'):