| @@ -189,8 +189,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| self._received_hit = [] | self._received_hit = [] | ||||
| for watchpoint_hit in watchpoint_hits: | for watchpoint_hit in watchpoint_hits: | ||||
| watchpoint_hit_stream.put(watchpoint_hit) | watchpoint_hit_stream.put(watchpoint_hit) | ||||
| watchpoint_hits_info = watchpoint_hit_stream.get() | |||||
| watchpoint_hits_info.update({'receive_watchpoint_hits': True}) | |||||
| watchpoint_hits_info = {'receive_watchpoint_hits': True} | |||||
| self._cache_store.put_data(watchpoint_hits_info) | self._cache_store.put_data(watchpoint_hits_info) | ||||
| log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.") | log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.") | ||||
| @@ -242,7 +242,6 @@ class DebuggerServer: | |||||
| 'all': self._retrieve_all, | 'all': self._retrieve_all, | ||||
| 'node': self._retrieve_node, | 'node': self._retrieve_node, | ||||
| 'watchpoint': self._retrieve_watchpoint, | 'watchpoint': self._retrieve_watchpoint, | ||||
| 'watchpoint_hit': self._retrieve_watchpoint_hit | |||||
| } | } | ||||
| # validate param <mode> | # validate param <mode> | ||||
| if mode not in mode_mapping.keys(): | if mode not in mode_mapping.keys(): | ||||
| @@ -470,40 +469,6 @@ class DebuggerServer: | |||||
| return reply | return reply | ||||
| def _retrieve_watchpoint_hit(self, filter_condition): | |||||
| """ | |||||
| Retrieve watchpoint hit. | |||||
| Args: | |||||
| filter_condition (dict): Filter condition. | |||||
| - name (str): The name of single node. | |||||
| - single_node (bool): If False, return the sub-layer of single node. If True, return | |||||
| the node list from root node to single node. | |||||
| Returns: | |||||
| dict, watch point list or relative graph. | |||||
| """ | |||||
| node_name = filter_condition.get('name') | |||||
| # get all watchpoint hit list | |||||
| if node_name is None: | |||||
| reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() | |||||
| reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() | |||||
| return reply | |||||
| graph_name = self.cache_store.get_stream_handler(Streams.GRAPH).validate_graph_name( | |||||
| filter_condition.get('graph_name')) | |||||
| # get tensor history | |||||
| reply = self._get_tensor_history(node_name, graph_name) | |||||
| log.debug("Get tensor history for watchpoint hit node.") | |||||
| # get single graph | |||||
| if filter_condition.get('single_node'): | |||||
| filter_condition['graph_name'] = graph_name | |||||
| graph = self._get_nodes_info(filter_condition) | |||||
| reply.update(graph) | |||||
| log.debug("Get tensor history for watchpoint hit node.") | |||||
| return reply | |||||
| def search_watchpoint_hits(self, group_condition): | def search_watchpoint_hits(self, group_condition): | ||||
| """ | """ | ||||
| Retrieve watchpoint hit. | Retrieve watchpoint hit. | ||||
| @@ -71,8 +71,7 @@ class TestAscendDebugger: | |||||
| 'retrieve_aggregation_scope_node.json'), | 'retrieve_aggregation_scope_node.json'), | ||||
| ({'mode': 'node', 'params': { | ({'mode': 'node', 'params': { | ||||
| 'name': 'Default/TransData-op99', | 'name': 'Default/TransData-op99', | ||||
| 'single_node': True}}, 'retrieve_single_node.json'), | |||||
| ({'mode': 'watchpoint_hit'}, 'retrieve_empty_watchpoint_hit_list') | |||||
| 'single_node': True}}, 'retrieve_single_node.json') | |||||
| ]) | ]) | ||||
| def test_retrieve_when_train_begin(self, app_client, body_data, expect_file): | def test_retrieve_when_train_begin(self, app_client, body_data, expect_file): | ||||
| """Test retrieve when train_begin.""" | """Test retrieve when train_begin.""" | ||||
| @@ -178,34 +177,6 @@ class TestAscendDebugger: | |||||
| send_and_compare_result(app_client, url, body_data, expect_file, method='get') | send_and_compare_result(app_client, url, body_data, expect_file, method='get') | ||||
| send_terminate_cmd(app_client) | send_terminate_cmd(app_client) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_watchpoint_hit(self, app_client): | |||||
| """Test retrieve watchpoint hit.""" | |||||
| with self._debugger_client.get_thread_instance(): | |||||
| create_watchpoint_and_wait(app_client) | |||||
| # check watchpoint hit list | |||||
| url = 'retrieve' | |||||
| body_data = {'mode': 'watchpoint_hit'} | |||||
| expect_file = 'retrieve_watchpoint_hit.json' | |||||
| send_and_compare_result(app_client, url, body_data, expect_file) | |||||
| # check single watchpoint hit | |||||
| body_data = { | |||||
| 'mode': 'watchpoint_hit', | |||||
| 'params': { | |||||
| 'name': 'Default/TransData-op99', | |||||
| 'single_node': True, | |||||
| 'watch_point_id': 1 | |||||
| } | |||||
| } | |||||
| expect_file = 'retrieve_single_watchpoint_hit.json' | |||||
| send_and_compare_result(app_client, url, body_data, expect_file) | |||||
| send_terminate_cmd(app_client) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @@ -180,14 +180,6 @@ class TestDebuggerServer: | |||||
| res = self._server._retrieve_watchpoint({'watch_point_id': 1}) | res = self._server._retrieve_watchpoint({'watch_point_id': 1}) | ||||
| assert res == mock_watchpoint | assert res == mock_watchpoint | ||||
| @mock.patch.object(DebuggerServer, '_get_tensor_history') | |||||
| @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}}) | |||||
| def test_retrieve_watchpoint_hit(self, *args): | |||||
| """Test retrieve single watchpoint.""" | |||||
| args[1].return_value = {'tensor_history': {}} | |||||
| res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True}) | |||||
| assert res == {'tensor_history': {}, 'graph': {}} | |||||
| def test_create_watchpoint_with_wrong_state(self): | def test_create_watchpoint_with_wrong_state(self): | ||||
| """Test create watchpoint with wrong state.""" | """Test create watchpoint with wrong state.""" | ||||
| with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'): | with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'): | ||||