diff --git a/mindinsight/debugger/stream_handler/tensor_handler.py b/mindinsight/debugger/stream_handler/tensor_handler.py index e6478683..dd13588b 100644 --- a/mindinsight/debugger/stream_handler/tensor_handler.py +++ b/mindinsight/debugger/stream_handler/tensor_handler.py @@ -223,22 +223,22 @@ class TensorHandler(StreamHandlerBase): node_type = tensor_info.get('node_type') basic_info = self._get_basic_info(tensor_name, node_type) # add `has_prev_step` field to tensor basic info. - missing_tensor_infos = self._update_has_prev_step_field(basic_info, tensor_name, node_type) + missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type) if basic_info: tensor_info.update(basic_info) - if missing_tensor_infos: - missed_tensors.extend(missing_tensor_infos) + if missing_tensors_info: + missed_tensors.extend(missing_tensors_info) return missed_tensors def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): """Update has_prev_step field in tensor info.""" - missing_tensor_infos = self.get_missing_tensor_info(tensor_name, node_type) - if not missing_tensor_infos and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0: + missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type) + if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0: tensor_info['has_prev_step'] = True - return missing_tensor_infos + return missing_tensors_info - def get_missing_tensor_info(self, tensor_name, node_type): + def _get_missing_tensor_info(self, tensor_name, node_type): """ Get missing tensor infos. @@ -250,16 +250,16 @@ class TensorHandler(StreamHandlerBase): list, list of missing tensor basic information. """ step = self.cur_step - missing_tensor_infos = [] + missing_tensors_info = [] # check the current step value is missing if self._is_tensor_value_missing(tensor_name, step): - missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='')) + missing_tensors_info.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='')) log.debug("Add current step view cmd for %s", tensor_name) # check the previous step value is missing if node_type == NodeTypeEnum.PARAMETER.value and self._is_tensor_value_missing(tensor_name, step - 1): - missing_tensor_infos.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='prev')) + missing_tensors_info.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='prev')) log.debug("Add previous view cmd for %s", tensor_name) - return missing_tensor_infos + return missing_tensors_info def _is_tensor_value_missing(self, tensor_name, step): """ @@ -374,19 +374,22 @@ class TensorHandler(StreamHandlerBase): stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats) return stats_info - def get_tensor_statistics(self, tensor_name, node_type): + def get_tensor_info_for_tensor_graph(self, tensor_name, node_type): """ - Get Tensor statistics. + Get Tensor info for tensor graphs. Args: tensor_name (str): Tensor name, format like `node_name:slot`. node_type (str): Node type. Returns: - dict, overall statistics. + dict, tensor infos, including overall statistics, tensor shape and has_prev_step info. + list, list of missing tensor basic information. """ res = {} tensor = self._get_tensor(tensor_name, node_type) if tensor and not tensor.empty: - res = tensor.get_tensor_statistics() - return res + res['statistics'] = tensor.get_tensor_statistics() + res['shape'] = tensor.shape + missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type) + return res, missing_tensors diff --git a/mindinsight/debugger/stream_operator/tensor_detail_info.py b/mindinsight/debugger/stream_operator/tensor_detail_info.py index 32b48ed5..a785478a 100644 --- a/mindinsight/debugger/stream_operator/tensor_detail_info.py +++ b/mindinsight/debugger/stream_operator/tensor_detail_info.py @@ -78,14 +78,14 @@ class TensorDetailInfo: node['graph_name'] = graph_name for slot_info in node.get('slots', []): self._add_watchpoint_hit_info(slot_info, node) - self._add_statistic_info(slot_info, node, missing_tensors) + self._add_tensor_info(slot_info, node, missing_tensors) # query missing tensor values from client self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name) return graph def _add_watchpoint_hit_info(self, slot_info, node): """ - Get the watchpoint that the tensor hit. + Add watchpoint hit info for the tensor. Args: slot_info (dict): Slot object. @@ -94,9 +94,9 @@ class TensorDetailInfo: tensor_name = ':'.join([node.get('name'), slot_info.get('slot')]) slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name)) - def _add_statistic_info(self, slot_info, node, missing_tensors): + def _add_tensor_info(self, slot_info, node, missing_tensors): """ - Get the watchpoint that the tensor hit. + Add the tensor info and query for missed tensors. Args: slot_info (dict): Slot object. @@ -105,10 +105,10 @@ class TensorDetailInfo: """ tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')]) node_type = node.get('type') - slot_info['statistics'] = self._tensor_stream.get_tensor_statistics(tensor_name, node_type) - if not slot_info.get('statistics'): + tensor_info, cur_missing_tensors = self._tensor_stream.get_tensor_info_for_tensor_graph(tensor_name, node_type) + slot_info.update(tensor_info) + if cur_missing_tensors: log.debug("Get missing tensor basic infos for %s", tensor_name) - cur_missing_tensors = self._tensor_stream.get_missing_tensor_info(tensor_name, node_type) missing_tensors.extend(cur_missing_tensors) def _ask_for_missing_tensor_value(self, missing_tensors, tensor_name, graph_name): diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-0.json b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-0.json index dc909afa..eda1ce4c 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-0.json +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-0.json @@ -30,7 +30,23 @@ "slots": [ { "slot": "0", - "statistics": {} + "statistics": { + "overall_max": 6.0, + "overall_min": 1.0, + "overall_avg": 3.5, + "overall_count": 6, + "overall_nan_count": 0, + "overall_neg_inf_count": 0, + "overall_pos_inf_count": 0, + "overall_zero_count": 0.0, + "overall_neg_zero_count": 0.0, + "overall_pos_zero_count": 6.0 + }, + "shape": [ + 2, + 3 + ], + "has_prev_step": true } ], "graph_name": "graph_0" @@ -64,7 +80,22 @@ "slots": [ { "slot": "0", - "statistics": {} + "statistics": { + "overall_max": 6.0, + "overall_min": 1.0, + "overall_avg": 3.5, + "overall_count": 6, + "overall_nan_count": 0, + "overall_neg_inf_count": 0, + "overall_pos_inf_count": 0, + "overall_zero_count": 0.0, + "overall_neg_zero_count": 0.0, + "overall_pos_zero_count": 6.0 + }, + "shape": [ + 2, + 3 + ] } ], "graph_name": "graph_0" @@ -128,7 +159,22 @@ } } ], - "statistics": {} + "statistics": { + "overall_max": 6.0, + "overall_min": 1.0, + "overall_avg": 3.5, + "overall_count": 6, + "overall_nan_count": 0, + "overall_neg_inf_count": 0, + "overall_pos_inf_count": 0, + "overall_zero_count": 0.0, + "overall_neg_zero_count": 0.0, + "overall_pos_zero_count": 6.0 + }, + "shape": [ + 2, + 3 + ] } ], "graph_name": "graph_0" diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-1.json b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-1.json index a54d3e19..c6c88dfd 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-1.json +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_graph-1.json @@ -27,11 +27,41 @@ "slots": [ { "slot": "0", - "statistics": {} + "statistics": { + "overall_max": 6.0, + "overall_min": 1.0, + "overall_avg": 3.5, + "overall_count": 6, + "overall_nan_count": 0, + "overall_neg_inf_count": 0, + "overall_pos_inf_count": 0, + "overall_zero_count": 0.0, + "overall_neg_zero_count": 0.0, + "overall_pos_zero_count": 6.0 + }, + "shape": [ + 2, + 3 + ] }, { "slot": "1", - "statistics": {} + "statistics": { + "overall_max": 6.0, + "overall_min": 1.0, + "overall_avg": 3.5, + "overall_count": 6, + "overall_nan_count": 0, + "overall_neg_inf_count": 0, + "overall_pos_inf_count": 0, + "overall_zero_count": 0.0, + "overall_neg_zero_count": 0.0, + "overall_pos_zero_count": 6.0 + }, + "shape": [ + 2, + 3 + ] } ], "graph_name": "graph_0" @@ -62,7 +92,23 @@ "slots": [ { "slot": "0", - "statistics": {} + "statistics": { + "overall_max": 6.0, + "overall_min": 1.0, + "overall_avg": 3.5, + "overall_count": 6, + "overall_nan_count": 0, + "overall_neg_inf_count": 0, + "overall_pos_inf_count": 0, + "overall_zero_count": 0.0, + "overall_neg_zero_count": 0.0, + "overall_pos_zero_count": 6.0 + }, + "shape": [ + 2, + 3 + ], + "has_prev_step": true } ], "graph_name": "graph_0" diff --git a/tests/st/func/debugger/test_restful_api.py b/tests/st/func/debugger/test_restful_api.py index e45165d8..65a8bbc6 100644 --- a/tests/st/func/debugger/test_restful_api.py +++ b/tests/st/func/debugger/test_restful_api.py @@ -382,6 +382,11 @@ class TestAscendDebugger: url = 'tensor-graphs' with self._debugger_client.get_thread_instance(): create_watchpoint_and_wait(app_client) + get_request_result(app_client, url, body_data, method='GET') + # check full tensor history from poll data + res = get_request_result( + app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get') + assert res.get('receive_tensor', {}).get('tensor_name') == body_data.get('tensor_name') send_and_compare_result(app_client, url, body_data, expect_file, method='GET') send_terminate_cmd(app_client)