| @@ -189,8 +189,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._received_hit = [] | |||
| for watchpoint_hit in watchpoint_hits: | |||
| watchpoint_hit_stream.put(watchpoint_hit) | |||
| watchpoint_hits_info = watchpoint_hit_stream.get() | |||
| watchpoint_hits_info.update({'receive_watchpoint_hits': True}) | |||
| watchpoint_hits_info = {'receive_watchpoint_hits': True} | |||
| self._cache_store.put_data(watchpoint_hits_info) | |||
| log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.") | |||
| @@ -242,7 +242,6 @@ class DebuggerServer: | |||
| 'all': self._retrieve_all, | |||
| 'node': self._retrieve_node, | |||
| 'watchpoint': self._retrieve_watchpoint, | |||
| 'watchpoint_hit': self._retrieve_watchpoint_hit | |||
| } | |||
| # validate param <mode> | |||
| if mode not in mode_mapping.keys(): | |||
| @@ -470,40 +469,6 @@ class DebuggerServer: | |||
| return reply | |||
| def _retrieve_watchpoint_hit(self, filter_condition): | |||
| """ | |||
| Retrieve watchpoint hit. | |||
| Args: | |||
| filter_condition (dict): Filter condition. | |||
| - name (str): The name of single node. | |||
| - single_node (bool): If False, return the sub-layer of single node. If True, return | |||
| the node list from root node to single node. | |||
| Returns: | |||
| dict, watch point list or relative graph. | |||
| """ | |||
| node_name = filter_condition.get('name') | |||
| # get all watchpoint hit list | |||
| if node_name is None: | |||
| reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() | |||
| reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() | |||
| return reply | |||
| graph_name = self.cache_store.get_stream_handler(Streams.GRAPH).validate_graph_name( | |||
| filter_condition.get('graph_name')) | |||
| # get tensor history | |||
| reply = self._get_tensor_history(node_name, graph_name) | |||
| log.debug("Get tensor history for watchpoint hit node.") | |||
| # get single graph | |||
| if filter_condition.get('single_node'): | |||
| filter_condition['graph_name'] = graph_name | |||
| graph = self._get_nodes_info(filter_condition) | |||
| reply.update(graph) | |||
| log.debug("Get tensor history for watchpoint hit node.") | |||
| return reply | |||
| def search_watchpoint_hits(self, group_condition): | |||
| """ | |||
| Retrieve watchpoint hit. | |||
| @@ -71,8 +71,7 @@ class TestAscendDebugger: | |||
| 'retrieve_aggregation_scope_node.json'), | |||
| ({'mode': 'node', 'params': { | |||
| 'name': 'Default/TransData-op99', | |||
| 'single_node': True}}, 'retrieve_single_node.json'), | |||
| ({'mode': 'watchpoint_hit'}, 'retrieve_empty_watchpoint_hit_list') | |||
| 'single_node': True}}, 'retrieve_single_node.json') | |||
| ]) | |||
| def test_retrieve_when_train_begin(self, app_client, body_data, expect_file): | |||
| """Test retrieve when train_begin.""" | |||
| @@ -178,34 +177,6 @@ class TestAscendDebugger: | |||
| send_and_compare_result(app_client, url, body_data, expect_file, method='get') | |||
| send_terminate_cmd(app_client) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_watchpoint_hit(self, app_client): | |||
| """Test retrieve watchpoint hit.""" | |||
| with self._debugger_client.get_thread_instance(): | |||
| create_watchpoint_and_wait(app_client) | |||
| # check watchpoint hit list | |||
| url = 'retrieve' | |||
| body_data = {'mode': 'watchpoint_hit'} | |||
| expect_file = 'retrieve_watchpoint_hit.json' | |||
| send_and_compare_result(app_client, url, body_data, expect_file) | |||
| # check single watchpoint hit | |||
| body_data = { | |||
| 'mode': 'watchpoint_hit', | |||
| 'params': { | |||
| 'name': 'Default/TransData-op99', | |||
| 'single_node': True, | |||
| 'watch_point_id': 1 | |||
| } | |||
| } | |||
| expect_file = 'retrieve_single_watchpoint_hit.json' | |||
| send_and_compare_result(app_client, url, body_data, expect_file) | |||
| send_terminate_cmd(app_client) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -180,14 +180,6 @@ class TestDebuggerServer: | |||
| res = self._server._retrieve_watchpoint({'watch_point_id': 1}) | |||
| assert res == mock_watchpoint | |||
| @mock.patch.object(DebuggerServer, '_get_tensor_history') | |||
| @mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}}) | |||
| def test_retrieve_watchpoint_hit(self, *args): | |||
| """Test retrieve single watchpoint.""" | |||
| args[1].return_value = {'tensor_history': {}} | |||
| res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True}) | |||
| assert res == {'tensor_history': {}, 'graph': {}} | |||
| def test_create_watchpoint_with_wrong_state(self): | |||
| """Test create watchpoint with wrong state.""" | |||
| with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'): | |||