diff --git a/mindinsight/debugger/conditionmgr/condition.py b/mindinsight/debugger/conditionmgr/condition.py index 50686b0d..e75cedc8 100644 --- a/mindinsight/debugger/conditionmgr/condition.py +++ b/mindinsight/debugger/conditionmgr/condition.py @@ -68,9 +68,10 @@ class PlatformEnum(Enum): class TargetTypeEnum(Enum): """Target types.""" TENSOR = 'tensor' - WEIGHT = 'weight' ACTIVATION = 'activation' GRADIENT = 'gradient' + PARAMETER = 'parameter' + WEIGHT = 'weight' class ParamTypeEnum(Enum): diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py index 3cbaf30f..00055569 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -21,6 +21,7 @@ import mindinsight from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ Streams, RunLevel +from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto @@ -117,9 +118,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): # clean cache data at the beginning of new step or node has been changed. 
if is_new_step or is_new_node: self._cache_store.clean_data() + self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step) if is_new_step: self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() - self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step) # receive graph at the beginning of the training if self._status == ServerStatus.RECEIVE_GRAPH: self._send_graph_flag(metadata_stream) @@ -397,6 +398,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals) self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name + self._record_parameter_names() self._status = ServerStatus.RECEIVE_GRAPH log.debug("Send the reply for graph.") return reply @@ -423,10 +425,20 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): sub_graph.const_vals) self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) + self._record_parameter_names() self._status = ServerStatus.RECEIVE_GRAPH log.debug("Send the reply for graph.") return reply + def _record_parameter_names(self): + """Record parameter full names in tensor handler.""" + parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_searched_nodes( + pattern={'node_category': TargetTypeEnum.PARAMETER.value}) + tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) + for nodes in parameter_nodes.values(): + tensor_names = [node.full_name + ':0' for node in nodes] + tensor_stream.record_parameter_names(tensor_names) + @debugger_wrap def SendTensors(self, request_iterator, context): """Send tensors into DebuggerCache.""" diff --git a/mindinsight/debugger/stream_cache/node_type_identifier.py b/mindinsight/debugger/stream_cache/node_type_identifier.py index 66bffd1f..9f4e68e8 100644 --- 
a/mindinsight/debugger/stream_cache/node_type_identifier.py +++ b/mindinsight/debugger/stream_cache/node_type_identifier.py @@ -66,6 +66,19 @@ class NodeTypeIdentifier: return self.identify_func(*args, **kwargs) +def is_parameter_node(node): + """ + Check if the node is parameter type. + + Args: + node (Node): The node object. + + Returns: + bool, if the node is parameter type. + """ + return bool(node.type == NodeTypeEnum.PARAMETER.value) + + def is_weight_node(node): """ Check if the node is weight type. diff --git a/mindinsight/debugger/stream_handler/tensor_handler.py b/mindinsight/debugger/stream_handler/tensor_handler.py index dd13588b..154fece6 100644 --- a/mindinsight/debugger/stream_handler/tensor_handler.py +++ b/mindinsight/debugger/stream_handler/tensor_handler.py @@ -27,11 +27,17 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) + class TensorHandler(StreamHandlerBase): """Metadata Handler.""" def __init__(self): + # the collection of parameter full names + self._param_names = set() + # const value objects, the format is like: dict[<const name>, <ConstTensor object>] self._const_vals = {} + # tensor values, the format is like: + # dict[<tensor full name>, dict[<step>, <tensor object>]] self._tensors = {} self._cur_step = 0 @@ -147,6 +153,19 @@ class TensorHandler(StreamHandlerBase): const_tensor = ConstTensor(const_val) self._const_vals[const_tensor.name] = const_tensor + def record_parameter_names(self, names): + """ + Record parameter names. + + Note: + Parameter values could be changed during an iteration step. They must be cleaned after each step. + + Args: + names (list[str]): List of tensor full names. + """ + self._param_names.update(names) + log.debug("Record %d parameters in cache. Total parameter number: %d", len(names), len(self._param_names)) + def get(self, filter_condition=None): """ Get full tensor value. 
@@ -293,7 +312,13 @@ class TensorHandler(StreamHandlerBase): def clean_tensors(self, cur_step): """Clean the tensor cache.""" - self._cur_step = cur_step + if cur_step != self._cur_step: + self._cur_step = cur_step + self._clean_expired_tensors(cur_step) + self._clean_parameters() + + def _clean_expired_tensors(self, cur_step): + """Clean cached tensor values from steps earlier than ``cur_step - 1``.""" expired_tensor = [] for tensor_name, tensor in self._tensors.items(): expired_step = [step for step in tensor.keys() if step <= cur_step - 2] @@ -304,6 +329,13 @@ class TensorHandler(StreamHandlerBase): for tensor_name in expired_tensor: self._tensors.pop(tensor_name) + def _clean_parameters(self): + """Clean parameter cache.""" + for param in self._param_names: + if param in self._tensors: + self._tensors.pop(param) + log.debug("Clean param %s in cache.", param) + def get_tensors_diff(self, tensor_name, shape, tolerance=0): """ Get tensor comparisons data for given name, detail, shape and tolerance. diff --git a/tests/st/func/debugger/mock_ms_client.py b/tests/st/func/debugger/mock_ms_client.py index 30090792..b558b359 100644 --- a/tests/st/func/debugger/mock_ms_client.py +++ b/tests/st/func/debugger/mock_ms_client.py @@ -123,11 +123,10 @@ class MockDebuggerClient: metadata = self.get_metadata_cmd(training_done) response = self.stub.SendMetadata(metadata) assert response.status == EventReply.Status.OK - if response.version_matched is False: + if response.HasField('version_matched') and response.version_matched is False: self.command_loop() if training_done is False: self.send_graph_cmd() - print("finish") def send_graph_cmd(self): """Send graph to debugger server.""" diff --git a/tests/st/func/debugger/test_restful_api.py b/tests/st/func/debugger/test_restful_api.py index 1e5e8fdf..1b102191 100644 --- a/tests/st/func/debugger/test_restful_api.py +++ b/tests/st/func/debugger/test_restful_api.py @@ -672,6 +672,7 @@ def create_watchpoint_and_wait(app_client): # wait for server has received 
watchpoint hit check_state(app_client) + class TestMismatchDebugger: """Test debugger when Mindinsight and Mindspore is mismatched."""