|
|
@@ -45,7 +45,6 @@ class DebuggerServer: |
|
|
self.grpc_server = DebuggerGrpcServer(self.cache_store) |
|
|
self.grpc_server = DebuggerGrpcServer(self.cache_store) |
|
|
self.grpc_server_manager = None |
|
|
self.grpc_server_manager = None |
|
|
self.back_server = None |
|
|
self.back_server = None |
|
|
self._watch_point_id = 0 |
|
|
|
|
|
|
|
|
|
|
|
def start(self): |
|
|
def start(self): |
|
|
"""Start server.""" |
|
|
"""Start server.""" |
|
|
@@ -95,12 +94,23 @@ class DebuggerServer: |
|
|
|
|
|
|
|
|
return reply |
|
|
return reply |
|
|
|
|
|
|
|
|
def search(self, name, watch_point_id): |
|
|
|
|
|
"""Search for single node in graph.""" |
|
|
|
|
|
|
|
|
def search(self, name, watch_point_id=0): |
|
|
|
|
|
""" |
|
|
|
|
|
Search for single node in graph. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
name (str): The name pattern. |
|
|
|
|
|
watch_point_id (int): The id of watchpoint. Default: 0. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
dict, the searched nodes. |
|
|
|
|
|
""" |
|
|
log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id) |
|
|
log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id) |
|
|
|
|
|
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) |
|
|
|
|
|
watchpoint_stream.validate_watchpoint_id(watch_point_id) |
|
|
graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name) |
|
|
graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name) |
|
|
self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes( |
|
|
|
|
|
graph, watch_point_id) |
|
|
|
|
|
|
|
|
# add watched label to graph |
|
|
|
|
|
watchpoint_stream.set_watch_nodes(graph, watch_point_id) |
|
|
return graph |
|
|
return graph |
|
|
|
|
|
|
|
|
def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): |
|
|
def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): |
|
|
@@ -155,8 +165,6 @@ class DebuggerServer: |
|
|
""" |
|
|
""" |
|
|
log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode, |
|
|
log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode, |
|
|
filter_condition) |
|
|
filter_condition) |
|
|
# validate watchpoint_id |
|
|
|
|
|
|
|
|
|
|
|
mode_mapping = { |
|
|
mode_mapping = { |
|
|
'all': self._retrieve_all, |
|
|
'all': self._retrieve_all, |
|
|
'node': self._retrieve_node, |
|
|
'node': self._retrieve_node, |
|
|
@@ -178,10 +186,9 @@ class DebuggerServer: |
|
|
if filter_condition: |
|
|
if filter_condition: |
|
|
log.error("No filter condition required for retrieve all request.") |
|
|
log.error("No filter condition required for retrieve all request.") |
|
|
raise DebuggerParamTypeError("filter_condition should be empty.") |
|
|
raise DebuggerParamTypeError("filter_condition should be empty.") |
|
|
result = {} |
|
|
|
|
|
self._watch_point_id = 0 |
|
|
|
|
|
self.cache_store.clean_data() |
|
|
self.cache_store.clean_data() |
|
|
log.info("Clean data queue cache when retrieve all request.") |
|
|
log.info("Clean data queue cache when retrieve all request.") |
|
|
|
|
|
result = {} |
|
|
for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]: |
|
|
for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]: |
|
|
sub_res = self.cache_store.get_stream_handler(stream).get() |
|
|
sub_res = self.cache_store.get_stream_handler(stream).get() |
|
|
result.update(sub_res) |
|
|
result.update(sub_res) |
|
|
@@ -200,13 +207,15 @@ class DebuggerServer: |
|
|
- single_node (bool): If False, return the sub-layer of single node. If True, return |
|
|
- single_node (bool): If False, return the sub-layer of single node. If True, return |
|
|
the node list from root node to single node. |
|
|
the node list from root node to single node. |
|
|
|
|
|
|
|
|
|
|
|
- watch_point_id (int): The id of watchpoint. |
|
|
|
|
|
|
|
|
Returns: |
|
|
Returns: |
|
|
dict, the node info. |
|
|
|
|
|
|
|
|
dict, reply with graph. |
|
|
""" |
|
|
""" |
|
|
log.info("Retrieve node %s.", filter_condition) |
|
|
log.info("Retrieve node %s.", filter_condition) |
|
|
|
|
|
# validate node name |
|
|
node_name = filter_condition.get('name') |
|
|
node_name = filter_condition.get('name') |
|
|
if node_name: |
|
|
if node_name: |
|
|
# validate node name |
|
|
|
|
|
self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name) |
|
|
self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name) |
|
|
filter_condition['single_node'] = bool(filter_condition.get('single_node')) |
|
|
filter_condition['single_node'] = bool(filter_condition.get('single_node')) |
|
|
reply = self._get_nodes_info(filter_condition) |
|
|
reply = self._get_nodes_info(filter_condition) |
|
|
@@ -224,16 +233,21 @@ class DebuggerServer: |
|
|
- single_node (bool): If False, return the sub-layer of single node. If True, return |
|
|
- single_node (bool): If False, return the sub-layer of single node. If True, return |
|
|
the node list from root node to single node. |
|
|
the node list from root node to single node. |
|
|
|
|
|
|
|
|
|
|
|
- watch_point_id (int): The id of watchpoint. |
|
|
|
|
|
|
|
|
Returns: |
|
|
Returns: |
|
|
dict, reply with graph. |
|
|
dict, reply with graph. |
|
|
""" |
|
|
""" |
|
|
|
|
|
# validate watch_point_id |
|
|
|
|
|
watch_point_id = filter_condition.get('watch_point_id', 0) |
|
|
|
|
|
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) |
|
|
|
|
|
watchpoint_stream.validate_watchpoint_id(watch_point_id) |
|
|
# get graph |
|
|
# get graph |
|
|
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) |
|
|
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) |
|
|
reply = graph_stream.get(filter_condition) |
|
|
reply = graph_stream.get(filter_condition) |
|
|
graph = reply.get('graph') |
|
|
graph = reply.get('graph') |
|
|
# add watched label |
|
|
|
|
|
self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes( |
|
|
|
|
|
graph, self._watch_point_id) |
|
|
|
|
|
|
|
|
# add watched label to graph |
|
|
|
|
|
watchpoint_stream.set_watch_nodes(graph, watch_point_id) |
|
|
return reply |
|
|
return reply |
|
|
|
|
|
|
|
|
def retrieve_tensor_history(self, node_name): |
|
|
def retrieve_tensor_history(self, node_name): |
|
|
@@ -353,7 +367,7 @@ class DebuggerServer: |
|
|
Args: |
|
|
Args: |
|
|
filter_condition (dict): Filter condition. |
|
|
filter_condition (dict): Filter condition. |
|
|
|
|
|
|
|
|
- watch_point_id (int): The id of watchoint. If not given, return all watchpoints. |
|
|
|
|
|
|
|
|
- watch_point_id (int): The id of watchpoint. If not given, return all watchpoints. |
|
|
|
|
|
|
|
|
- name (str): The name of single node. |
|
|
- name (str): The name of single node. |
|
|
|
|
|
|
|
|
@@ -363,10 +377,7 @@ class DebuggerServer: |
|
|
Returns: |
|
|
Returns: |
|
|
dict, watch point list or relative graph. |
|
|
dict, watch point list or relative graph. |
|
|
""" |
|
|
""" |
|
|
watchpoint_id = filter_condition.get('watch_point_id') |
|
|
|
|
|
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) |
|
|
|
|
|
watchpoint_stream.validate_watchpoint_id(watchpoint_id) |
|
|
|
|
|
self._watch_point_id = watchpoint_id if watchpoint_id else 0 |
|
|
|
|
|
|
|
|
watchpoint_id = filter_condition.get('watch_point_id', 0) |
|
|
if not watchpoint_id: |
|
|
if not watchpoint_id: |
|
|
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get() |
|
|
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get() |
|
|
log.debug("Get condition of watchpoints.") |
|
|
log.debug("Get condition of watchpoints.") |
|
|
@@ -392,11 +403,11 @@ class DebuggerServer: |
|
|
dict, watch point list or relative graph. |
|
|
dict, watch point list or relative graph. |
|
|
""" |
|
|
""" |
|
|
node_name = filter_condition.get('name') |
|
|
node_name = filter_condition.get('name') |
|
|
# get watchpoint hit list |
|
|
|
|
|
|
|
|
# get all watchpoint hit list |
|
|
if node_name is None: |
|
|
if node_name is None: |
|
|
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() |
|
|
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() |
|
|
return reply |
|
|
return reply |
|
|
|
|
|
|
|
|
|
|
|
# get tensor history and graph of the hit node. |
|
|
self._validate_leaf_name(node_name) |
|
|
self._validate_leaf_name(node_name) |
|
|
# get tensor history |
|
|
# get tensor history |
|
|
reply = self._get_tensor_history(node_name) |
|
|
reply = self._get_tensor_history(node_name) |
|
|
@@ -438,7 +449,6 @@ class DebuggerServer: |
|
|
watch_nodes = self._get_node_basic_infos(watch_nodes) |
|
|
watch_nodes = self._get_node_basic_infos(watch_nodes) |
|
|
watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint( |
|
|
watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint( |
|
|
watch_condition, watch_nodes, watch_point_id) |
|
|
watch_condition, watch_nodes, watch_point_id) |
|
|
self._watch_point_id = 0 |
|
|
|
|
|
log.info("Create watchpoint %d", watch_point_id) |
|
|
log.info("Create watchpoint %d", watch_point_id) |
|
|
return {'id': watch_point_id} |
|
|
return {'id': watch_point_id} |
|
|
|
|
|
|
|
|
@@ -462,7 +472,9 @@ class DebuggerServer: |
|
|
raise DebuggerUpdateWatchPointError( |
|
|
raise DebuggerUpdateWatchPointError( |
|
|
"Failed to update watchpoint as the MindSpore is not in waiting state." |
|
|
"Failed to update watchpoint as the MindSpore is not in waiting state." |
|
|
) |
|
|
) |
|
|
# validate |
|
|
|
|
|
|
|
|
# validate parameter |
|
|
|
|
|
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) |
|
|
|
|
|
watchpoint_stream.validate_watchpoint_id(watch_point_id) |
|
|
if not watch_nodes or not watch_point_id: |
|
|
if not watch_nodes or not watch_point_id: |
|
|
log.error("Invalid parameter for update watchpoint.") |
|
|
log.error("Invalid parameter for update watchpoint.") |
|
|
raise DebuggerParamValueError("Invalid parameter for update watchpoint.") |
|
|
raise DebuggerParamValueError("Invalid parameter for update watchpoint.") |
|
|
@@ -472,9 +484,7 @@ class DebuggerServer: |
|
|
elif mode == 1: |
|
|
elif mode == 1: |
|
|
watch_nodes = self._get_node_basic_infos(watch_nodes) |
|
|
watch_nodes = self._get_node_basic_infos(watch_nodes) |
|
|
|
|
|
|
|
|
self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint( |
|
|
|
|
|
watch_point_id, watch_nodes, mode) |
|
|
|
|
|
self._watch_point_id = watch_point_id |
|
|
|
|
|
|
|
|
watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, mode) |
|
|
log.info("Update watchpoint with id: %d", watch_point_id) |
|
|
log.info("Update watchpoint with id: %d", watch_point_id) |
|
|
return {} |
|
|
return {} |
|
|
|
|
|
|
|
|
@@ -510,9 +520,7 @@ class DebuggerServer: |
|
|
raise DebuggerDeleteWatchPointError( |
|
|
raise DebuggerDeleteWatchPointError( |
|
|
"Failed to delete watchpoint as the MindSpore is not in waiting state." |
|
|
"Failed to delete watchpoint as the MindSpore is not in waiting state." |
|
|
) |
|
|
) |
|
|
self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint( |
|
|
|
|
|
watch_point_id) |
|
|
|
|
|
self._watch_point_id = 0 |
|
|
|
|
|
|
|
|
self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(watch_point_id) |
|
|
log.info("Delete watchpoint with id: %d", watch_point_id) |
|
|
log.info("Delete watchpoint with id: %d", watch_point_id) |
|
|
return {} |
|
|
return {} |
|
|
|
|
|
|
|
|
|