|
|
|
@@ -333,12 +333,13 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
# dict of <ui node_name, dict of <slot, WatchpointHit>>, |
|
|
|
self._hits = {} |
|
|
|
self._ordered_hits = [] |
|
|
|
self._multi_graph_hits = {} |
|
|
|
|
|
|
|
@property |
|
|
|
def empty(self): |
|
|
|
"""Whether the watchpoint hit is empty.""" |
|
|
|
return not self._hits |
|
|
|
return not self._multi_graph_hits |
|
|
|
|
|
|
|
def put(self, value): |
|
|
|
""" |
|
|
|
@@ -351,6 +352,7 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
- watchpoint (Watchpoint): The Watchpoint that a node hit. |
|
|
|
- node_name (str): The UI node name. |
|
|
|
- graph_name (str): The graph name. |
|
|
|
- error_code (int): The code of errors. |
|
|
|
""" |
|
|
|
watchpoint_hit = WatchpointHit( |
|
|
|
tensor_proto=value.get('tensor_proto'), |
|
|
|
@@ -361,12 +363,12 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
if 'error_code' in value.keys(): |
|
|
|
watchpoint_hit.error_code = value.get('error_code') |
|
|
|
# get all hit watchpoints according to node name ans tensor slot |
|
|
|
watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.node_name, |
|
|
|
watchpoint_hits = self._get_watchpoints_by_tensor_name(watchpoint_hit.graph_name, watchpoint_hit.node_name, |
|
|
|
watchpoint_hit.slot) |
|
|
|
if watchpoint_hit not in watchpoint_hits: |
|
|
|
watchpoint_hits.append(watchpoint_hit) |
|
|
|
|
|
|
|
def _get_watchpoints_by_tensor_name(self, node_name, slot): |
|
|
|
def _get_watchpoints_by_tensor_name(self, graph_name, node_name, slot): |
|
|
|
""" |
|
|
|
Get hit tensors according to ui node name and slot. |
|
|
|
|
|
|
|
@@ -377,14 +379,19 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
Returns: |
|
|
|
list, list of watchpoints. |
|
|
|
""" |
|
|
|
hit_node = self._hits.get(node_name) |
|
|
|
if hit_node is None: |
|
|
|
index = self._multi_graph_hits.get((graph_name, node_name)) |
|
|
|
if index is None: |
|
|
|
hit_node = {} |
|
|
|
self._hits[node_name] = hit_node |
|
|
|
self._ordered_hits.append(hit_node) |
|
|
|
index = len(self._ordered_hits) - 1 |
|
|
|
self._multi_graph_hits[(graph_name, node_name)] = index |
|
|
|
|
|
|
|
hit_node = self._ordered_hits[index] |
|
|
|
hit_tensors = hit_node.get(slot) |
|
|
|
if hit_tensors is None: |
|
|
|
hit_tensors = [] |
|
|
|
hit_node[slot] = hit_tensors |
|
|
|
|
|
|
|
return hit_tensors |
|
|
|
|
|
|
|
def get(self, filter_condition=None): |
|
|
|
@@ -398,34 +405,108 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
Returns: |
|
|
|
dict, the watchpoint hit list. |
|
|
|
""" |
|
|
|
reply = None |
|
|
|
if filter_condition is None: |
|
|
|
log.debug("Get all watchpoint hit list.") |
|
|
|
reply = self.get_watchpoint_hits() |
|
|
|
else: |
|
|
|
log.debug("Get the watchpoint for node: <%s>.", filter_condition) |
|
|
|
reply = self._hits.get(filter_condition) |
|
|
|
index = self._multi_graph_hits.get(("", filter_condition)) |
|
|
|
if index is not None: |
|
|
|
reply = self._ordered_hits[index] |
|
|
|
return reply |
|
|
|
|
|
|
|
def group_by(self, group_condition): |
|
|
|
""" |
|
|
|
Return the watchpoint hits by group condition. |
|
|
|
|
|
|
|
Args: |
|
|
|
group_condition (dict): The group conditions. |
|
|
|
|
|
|
|
- limit (int): The limit number of watchpoint hits each page. |
|
|
|
- offset (int): The page offset. |
|
|
|
- node_name (str): The node name. |
|
|
|
- graph_name (str): The graph name. |
|
|
|
|
|
|
|
Returns: |
|
|
|
dict, the watchpoint hit list. |
|
|
|
""" |
|
|
|
node_name = group_condition.get('node_name') |
|
|
|
# get all watchpoint hit list |
|
|
|
if node_name is None: |
|
|
|
reply = self._get_by_offset(group_condition) |
|
|
|
else: |
|
|
|
reply = self._get_by_name(group_condition) |
|
|
|
return reply |
|
|
|
|
|
|
|
def _get_by_offset(self, group_condition): |
|
|
|
"""Return the list of watchpoint hits on the offset page.""" |
|
|
|
limit = group_condition.get('limit') |
|
|
|
offset = group_condition.get('offset') |
|
|
|
if not isinstance(limit, int) or not isinstance(offset, int): |
|
|
|
log.error("Param limit or offset is not a integer") |
|
|
|
raise DebuggerParamValueError("Param limit or offset is not a integer") |
|
|
|
watch_point_hits = [] |
|
|
|
|
|
|
|
total = len(self._ordered_hits) |
|
|
|
|
|
|
|
if limit * offset >= total and offset != 0: |
|
|
|
log.error("Param offset out of bounds") |
|
|
|
raise DebuggerParamValueError("Param offset out of bounds") |
|
|
|
|
|
|
|
if total == 0: |
|
|
|
return {} |
|
|
|
|
|
|
|
for watchpoint_hits in self._ordered_hits[(limit * offset): (limit * (offset + 1))]: |
|
|
|
self._get_tensors(watchpoint_hits, watch_point_hits) |
|
|
|
|
|
|
|
return { |
|
|
|
'watch_point_hits': watch_point_hits, |
|
|
|
'offset': offset, |
|
|
|
'total': total |
|
|
|
} |
|
|
|
|
|
|
|
def _get_by_name(self, group_condition): |
|
|
|
"""Return the list of watchpoint hits by the group condition.""" |
|
|
|
limit = group_condition.get('limit') |
|
|
|
if not isinstance(limit, int) or limit == 0: |
|
|
|
log.error("Param limit is 0 or not a integer") |
|
|
|
raise DebuggerParamValueError("Param limit is 0 or not a integer") |
|
|
|
|
|
|
|
index = self._multi_graph_hits.get((group_condition.get('graph_name'), group_condition.get('node_name'))) |
|
|
|
if index is not None: |
|
|
|
group_condition['offset'] = index//limit |
|
|
|
return self._get_by_offset(group_condition) |
|
|
|
|
|
|
|
return {} |
|
|
|
|
|
|
|
def get_watchpoint_hits(self): |
|
|
|
"""Return the list of watchpoint hits.""" |
|
|
|
watch_point_hits = [] |
|
|
|
for node_name, watchpoint_hits in self._hits.items(): |
|
|
|
tensors = [] |
|
|
|
graph_name = None |
|
|
|
for slot, tensor_hits in watchpoint_hits.items(): |
|
|
|
if graph_name is None: |
|
|
|
graph_name = tensor_hits[0].graph_name |
|
|
|
tensor_info = self._get_tensor_hit_info(slot, tensor_hits) |
|
|
|
tensors.append(tensor_info) |
|
|
|
watch_point_hits.append({ |
|
|
|
'node_name': node_name, |
|
|
|
'tensors': tensors, |
|
|
|
'graph_name': graph_name |
|
|
|
}) |
|
|
|
for watchpoint_hits in self._ordered_hits: |
|
|
|
self._get_tensors(watchpoint_hits, watch_point_hits) |
|
|
|
|
|
|
|
return {'watch_point_hits': watch_point_hits} |
|
|
|
|
|
|
|
def _get_tensors(self, watchpoint_hits, watch_point_hits): |
|
|
|
"""Get the tensors info for the watchpoint_hits.""" |
|
|
|
tensors = [] |
|
|
|
graph_name = None |
|
|
|
node_name = None |
|
|
|
for slot, tensor_hits in watchpoint_hits.items(): |
|
|
|
if graph_name is None: |
|
|
|
graph_name = tensor_hits[0].graph_name |
|
|
|
if node_name is None: |
|
|
|
node_name = tensor_hits[0].node_name |
|
|
|
tensor_info = self._get_tensor_hit_info(slot, tensor_hits) |
|
|
|
tensors.append(tensor_info) |
|
|
|
|
|
|
|
watch_point_hits.append({ |
|
|
|
'node_name': node_name, |
|
|
|
'tensors': tensors, |
|
|
|
'graph_name': graph_name |
|
|
|
}) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _get_tensor_hit_info(slot, tensor_hits): |
|
|
|
""" |
|
|
|
@@ -457,19 +538,23 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
} |
|
|
|
return res |
|
|
|
|
|
|
|
def _is_tensor_hit(self, tensor_name): |
|
|
|
def _is_tensor_hit(self, tensor_name, graph_name): |
|
|
|
""" |
|
|
|
Check if the tensor is record in hit cache. |
|
|
|
|
|
|
|
Args: |
|
|
|
tensor_name (str): The name of ui tensor name. |
|
|
|
graph_name (str): The name of ui graph name |
|
|
|
|
|
|
|
Returns: |
|
|
|
bool, if the tensor is hit. |
|
|
|
""" |
|
|
|
node_name, slot = tensor_name.rsplit(':', 1) |
|
|
|
watchpoint_hits = self._hits.get(node_name, {}).get(slot) |
|
|
|
return bool(watchpoint_hits) |
|
|
|
index = self._multi_graph_hits.get((graph_name, node_name)) |
|
|
|
if index is not None: |
|
|
|
watchpoint_hits = self._ordered_hits[index].get(slot) |
|
|
|
return bool(watchpoint_hits) |
|
|
|
return False |
|
|
|
|
|
|
|
def update_tensor_history(self, tensor_history): |
|
|
|
""" |
|
|
|
@@ -478,16 +563,17 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
Args: |
|
|
|
tensor_history (dict): The tensor history. |
|
|
|
""" |
|
|
|
if not self._hits: |
|
|
|
if not self._multi_graph_hits: |
|
|
|
return |
|
|
|
|
|
|
|
# add hit tensor names to `tensor_names` |
|
|
|
for tensor_info in tensor_history.get('tensor_history'): |
|
|
|
tensor_name = tensor_info['name'] |
|
|
|
hit_flag = self._is_tensor_hit(tensor_name) |
|
|
|
graph_name = tensor_info['graph_name'] |
|
|
|
hit_flag = self._is_tensor_hit(tensor_name, graph_name) |
|
|
|
tensor_info['is_hit'] = hit_flag |
|
|
|
|
|
|
|
def get_tensor_hit_infos(self, tensor_name): |
|
|
|
def get_tensor_hit_infos(self, tensor_name, graph_name): |
|
|
|
""" |
|
|
|
Get all hit information of a tensor. |
|
|
|
|
|
|
|
@@ -498,9 +584,9 @@ class WatchpointHitHandler(StreamHandlerBase): |
|
|
|
dict, tensor hit info. |
|
|
|
""" |
|
|
|
tensor_hit_info = {} |
|
|
|
if self._is_tensor_hit(tensor_name): |
|
|
|
if self._is_tensor_hit(tensor_name, graph_name): |
|
|
|
node_name, slot = tensor_name.rsplit(':', 1) |
|
|
|
tensor_hits = self._get_watchpoints_by_tensor_name(node_name, slot) |
|
|
|
tensor_hits = self._get_watchpoints_by_tensor_name(graph_name, node_name, slot) |
|
|
|
tensor_hit_info = self._get_tensor_hit_info(slot, tensor_hits) |
|
|
|
return tensor_hit_info |
|
|
|
|
|
|
|
@@ -644,7 +730,7 @@ def _get_error_list(error_code): |
|
|
|
""" |
|
|
|
Get error list. |
|
|
|
Args: |
|
|
|
error_code (int): the code of errors. |
|
|
|
error_code (int): The code of errors. |
|
|
|
|
|
|
|
Returns: |
|
|
|
list, the error list. |
|
|
|
|