Browse Source

update pull_data and delete mode watchpoint_hit from retrieve

tags/v1.1.0
jiangshuqiang 4 years ago
parent
commit
81113a9227
4 changed files with 2 additions and 75 deletions
  1. +1
    -2
      mindinsight/debugger/debugger_grpc_server.py
  2. +0
    -35
      mindinsight/debugger/debugger_server.py
  3. +1
    -30
      tests/st/func/debugger/test_restful_api.py
  4. +0
    -8
      tests/ut/debugger/test_debugger_server.py

+ 1
- 2
mindinsight/debugger/debugger_grpc_server.py View File

@@ -189,8 +189,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._received_hit = [] self._received_hit = []
for watchpoint_hit in watchpoint_hits: for watchpoint_hit in watchpoint_hits:
watchpoint_hit_stream.put(watchpoint_hit) watchpoint_hit_stream.put(watchpoint_hit)
watchpoint_hits_info = watchpoint_hit_stream.get()
watchpoint_hits_info.update({'receive_watchpoint_hits': True})
watchpoint_hits_info = {'receive_watchpoint_hits': True}
self._cache_store.put_data(watchpoint_hits_info) self._cache_store.put_data(watchpoint_hits_info)
log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.") log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")




+ 0
- 35
mindinsight/debugger/debugger_server.py View File

@@ -242,7 +242,6 @@ class DebuggerServer:
'all': self._retrieve_all, 'all': self._retrieve_all,
'node': self._retrieve_node, 'node': self._retrieve_node,
'watchpoint': self._retrieve_watchpoint, 'watchpoint': self._retrieve_watchpoint,
'watchpoint_hit': self._retrieve_watchpoint_hit
} }
# validate param <mode> # validate param <mode>
if mode not in mode_mapping.keys(): if mode not in mode_mapping.keys():
@@ -470,40 +469,6 @@ class DebuggerServer:


return reply return reply


def _retrieve_watchpoint_hit(self, filter_condition):
"""
Retrieve watchpoint hit.

Args:
filter_condition (dict): Filter condition.

- name (str): The name of single node.
- single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node.

Returns:
dict, watch point list or relative graph.
"""
node_name = filter_condition.get('name')
# get all watchpoint hit list
if node_name is None:
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable()
return reply
graph_name = self.cache_store.get_stream_handler(Streams.GRAPH).validate_graph_name(
filter_condition.get('graph_name'))
# get tensor history
reply = self._get_tensor_history(node_name, graph_name)
log.debug("Get tensor history for watchpoint hit node.")
# get single graph
if filter_condition.get('single_node'):
filter_condition['graph_name'] = graph_name
graph = self._get_nodes_info(filter_condition)
reply.update(graph)
log.debug("Get tensor history for watchpoint hit node.")

return reply

def search_watchpoint_hits(self, group_condition): def search_watchpoint_hits(self, group_condition):
""" """
Retrieve watchpoint hit. Retrieve watchpoint hit.


+ 1
- 30
tests/st/func/debugger/test_restful_api.py View File

@@ -71,8 +71,7 @@ class TestAscendDebugger:
'retrieve_aggregation_scope_node.json'), 'retrieve_aggregation_scope_node.json'),
({'mode': 'node', 'params': { ({'mode': 'node', 'params': {
'name': 'Default/TransData-op99', 'name': 'Default/TransData-op99',
'single_node': True}}, 'retrieve_single_node.json'),
({'mode': 'watchpoint_hit'}, 'retrieve_empty_watchpoint_hit_list')
'single_node': True}}, 'retrieve_single_node.json')
]) ])
def test_retrieve_when_train_begin(self, app_client, body_data, expect_file): def test_retrieve_when_train_begin(self, app_client, body_data, expect_file):
"""Test retrieve when train_begin.""" """Test retrieve when train_begin."""
@@ -178,34 +177,6 @@ class TestAscendDebugger:
send_and_compare_result(app_client, url, body_data, expect_file, method='get') send_and_compare_result(app_client, url, body_data, expect_file, method='get')
send_terminate_cmd(app_client) send_terminate_cmd(app_client)


@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_watchpoint_hit(self, app_client):
"""Test retrieve watchpoint hit."""
with self._debugger_client.get_thread_instance():
create_watchpoint_and_wait(app_client)
# check watchpoint hit list
url = 'retrieve'
body_data = {'mode': 'watchpoint_hit'}
expect_file = 'retrieve_watchpoint_hit.json'
send_and_compare_result(app_client, url, body_data, expect_file)
# check single watchpoint hit
body_data = {
'mode': 'watchpoint_hit',
'params': {
'name': 'Default/TransData-op99',
'single_node': True,
'watch_point_id': 1
}
}
expect_file = 'retrieve_single_watchpoint_hit.json'
send_and_compare_result(app_client, url, body_data, expect_file)
send_terminate_cmd(app_client)

@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu


+ 0
- 8
tests/ut/debugger/test_debugger_server.py View File

@@ -180,14 +180,6 @@ class TestDebuggerServer:
res = self._server._retrieve_watchpoint({'watch_point_id': 1}) res = self._server._retrieve_watchpoint({'watch_point_id': 1})
assert res == mock_watchpoint assert res == mock_watchpoint


@mock.patch.object(DebuggerServer, '_get_tensor_history')
@mock.patch.object(DebuggerServer, '_get_nodes_info', return_value={'graph': {}})
def test_retrieve_watchpoint_hit(self, *args):
"""Test retrieve single watchpoint."""
args[1].return_value = {'tensor_history': {}}
res = self._server._retrieve_watchpoint_hit({'name': 'hit_node_name', 'single_node': True})
assert res == {'tensor_history': {}, 'graph': {}}

def test_create_watchpoint_with_wrong_state(self): def test_create_watchpoint_with_wrong_state(self):
"""Test create watchpoint with wrong state.""" """Test create watchpoint with wrong state."""
with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'): with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'):


Loading…
Cancel
Save