@@ -68,9 +68,10 @@ class PlatformEnum(Enum):
 class TargetTypeEnum(Enum):
     """Target types."""
     TENSOR = 'tensor'
-    WEIGHT = 'weight'
     ACTIVATION = 'activation'
     GRADIENT = 'gradient'
+    PARAMETER = 'parameter'
+    WEIGHT = 'weight'


 class ParamTypeEnum(Enum):
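With this hunk the target vocabulary gains a dedicated PARAMETER type while WEIGHT is kept; the value 'parameter' is what the server-side node search below filters on. A minimal sketch of the resulting enum (only the members shown in this diff):

```python
from enum import Enum

class TargetTypeEnum(Enum):
    """Target types (as reordered by this change)."""
    TENSOR = 'tensor'
    ACTIVATION = 'activation'
    GRADIENT = 'gradient'
    PARAMETER = 'parameter'
    WEIGHT = 'weight'

print([t.value for t in TargetTypeEnum])
# ['tensor', 'activation', 'gradient', 'parameter', 'weight']
```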
@@ -21,6 +21,7 @@ import mindinsight
 from mindinsight.debugger.common.log import LOGGER as log
 from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
     Streams, RunLevel
+from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum
 from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
 from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto
@@ -117,9 +118,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         # clean cache data at the beginning of new step or node has been changed.
         if is_new_step or is_new_node:
             self._cache_store.clean_data()
-            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
         if is_new_step:
             self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
+            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
         # receive graph at the beginning of the training
         if self._status == ServerStatus.RECEIVE_GRAPH:
             self._send_graph_flag(metadata_stream)
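This move narrows when tensor values are evicted: `clean_data` still runs on any node change, but `clean_tensors` now fires only when the step advances. A minimal sketch of that decision, with hypothetical `prev_step`/`prev_node` inputs (the real server derives these from its metadata stream):

```python
# Hypothetical helper illustrating the cleaning decision after this change.
def should_clean(prev_step, cur_step, prev_node, cur_node):
    is_new_step = cur_step > prev_step
    is_new_node = cur_node != prev_node
    clean_data = is_new_step or is_new_node   # command/UI cache
    clean_tensors = is_new_step               # tensor values survive node switches
    return clean_data, clean_tensors

print(should_clean(1, 1, 'Conv2D-op1', 'ReLU-op2'))  # (True, False): keep tensors
print(should_clean(1, 2, 'ReLU-op2', 'ReLU-op2'))    # (True, True): step advanced
```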
@@ -397,6 +398,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
         self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
         self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
+        self._record_parameter_names()
         self._status = ServerStatus.RECEIVE_GRAPH
         log.debug("Send the reply for graph.")
         return reply
@@ -423,10 +425,20 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
                 sub_graph.const_vals)
         self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
+        self._record_parameter_names()
         self._status = ServerStatus.RECEIVE_GRAPH
         log.debug("Send the reply for graph.")
         return reply
+
+    def _record_parameter_names(self):
+        """Record parameter full names in tensor handler."""
+        parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_searched_nodes(
+            pattern={'node_category': TargetTypeEnum.PARAMETER.value})
+        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
+        for nodes in parameter_nodes.values():
+            tensor_names = [node.full_name + ':0' for node in nodes]
+            tensor_stream.record_parameter_names(tensor_names)

     @debugger_wrap
     def SendTensors(self, request_iterator, context):
         """Send tensors into DebuggerCache."""
@@ -66,6 +66,19 @@ class NodeTypeIdentifier:
         return self.identify_func(*args, **kwargs)


+def is_parameter_node(node):
+    """
+    Check if the node is parameter type.
+
+    Args:
+        node (Node): The node object.
+
+    Returns:
+        bool, if the node is parameter type.
+    """
+    return bool(node.type == NodeTypeEnum.PARAMETER.value)
+
+
 def is_weight_node(node):
     """
     Check if the node is weight type.
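A usage sketch for the new `is_parameter_node`, with stub types standing in for the real graph classes; the concrete value of `NodeTypeEnum.PARAMETER` is an assumption here:

```python
from enum import Enum
from collections import namedtuple

class NodeTypeEnum(Enum):  # stub mirroring the real enum; value assumed
    PARAMETER = 'Parameter'

Node = namedtuple('Node', ['type'])

def is_parameter_node(node):
    """Check if the node is parameter type."""
    return bool(node.type == NodeTypeEnum.PARAMETER.value)

print(is_parameter_node(Node('Parameter')))  # True
print(is_parameter_node(Node('Conv2D')))     # False
```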
@@ -27,11 +27,17 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison
 TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])


 class TensorHandler(StreamHandlerBase):
     """Metadata Handler."""

     def __init__(self):
+        # the collection of parameter full names
+        self._param_names = set()
+        # const value objects, the format is like: dict[<const name>, <OpTensor object>]
         self._const_vals = {}
+        # tensor values, the format is like:
+        # dict[<tensor full name>, dict[<step_num>, <OpTensor object>]]
         self._tensors = {}
         self._cur_step = 0
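The new comments spell out the cache layout; a toy illustration of that structure with placeholder values:

```python
# _tensors: tensor full name -> step number -> tensor object
_tensors = {
    'Default/fc1.weight:0': {
        0: '<OpTensor for step 0>',
        1: '<OpTensor for step 1>',
    },
}
# _param_names: parameter full names, so their cached values can be dropped each step
_param_names = {'Default/fc1.weight:0'}
print(_tensors['Default/fc1.weight:0'][1])
```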
@@ -147,6 +153,19 @@ class TensorHandler(StreamHandlerBase):
             const_tensor = ConstTensor(const_val)
             self._const_vals[const_tensor.name] = const_tensor

+    def record_parameter_names(self, names):
+        """
+        Record parameter names.
+
+        Note:
+            Parameter values could be changed during an iteration step,
+            so cached parameter values must be cleaned after each step.
+
+        Args:
+            names (list[str]): List of tensor full names.
+        """
+        self._param_names.update(names)
+        log.debug("Record %d parameters in cache. Total parameter number: %d", len(names), len(self._param_names))
+
     def get(self, filter_condition=None):
         """
         Get full tensor value.
@@ -293,7 +312,13 @@ class TensorHandler(StreamHandlerBase):
     def clean_tensors(self, cur_step):
         """Clean the tensor cache."""
-        self._cur_step = cur_step
+        if cur_step != self._cur_step:
+            self._cur_step = cur_step
+            self._clean_expired_tensors(cur_step)
+        self._clean_parameters()
+
+    def _clean_expired_tensors(self, cur_step):
+        """Clean expired tensors from steps earlier than the current step."""
         expired_tensor = []
         for tensor_name, tensor in self._tensors.items():
             expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
@@ -304,6 +329,13 @@ class TensorHandler(StreamHandlerBase):
         for tensor_name in expired_tensor:
             self._tensors.pop(tensor_name)

+    def _clean_parameters(self):
+        """Clean parameter cache."""
+        for param in self._param_names:
+            if param in self._tensors:
+                self._tensors.pop(param)
+                log.debug("Clean param %s in cache.", param)
+
     def get_tensors_diff(self, tensor_name, shape, tolerance=0):
         """
         Get tensor comparisons data for given name, detail, shape and tolerance.
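Taken together, the two helpers give parameter values a shorter lifetime than other cached tensors. A toy reimplementation of the eviction policy, for illustration only, assuming non-parameter values are retained for the last two steps as the `cur_step - 2` check implies:

```python
def clean(tensors, param_names, cur_step):
    """Toy eviction: drop steps older than cur_step - 1, and all parameters."""
    for name in list(tensors):
        steps = tensors[name]
        for step in [s for s in steps if s <= cur_step - 2]:
            steps.pop(step)
        # parameters are dropped every pass because their values can change mid-step
        if not steps or name in param_names:
            tensors.pop(name)
    return tensors

cache = {'fc1.weight:0': {2: 'w2'}, 'Conv2D-op1:0': {1: 'a1', 2: 'a2'}}
print(clean(cache, {'fc1.weight:0'}, 3))  # {'Conv2D-op1:0': {2: 'a2'}}
```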
@@ -123,11 +123,10 @@ class MockDebuggerClient:
         metadata = self.get_metadata_cmd(training_done)
         response = self.stub.SendMetadata(metadata)
         assert response.status == EventReply.Status.OK
-        if response.version_matched is False:
+        if response.HasField('version_matched') and response.version_matched is False:
             self.command_loop()
         if training_done is False:
             self.send_graph_cmd()
-        print("finish")

     def send_graph_cmd(self):
         """Send graph to debugger server."""
@@ -672,6 +672,7 @@ def create_watchpoint_and_wait(app_client):
     # wait for server has received watchpoint hit
     check_state(app_client)

+
 class TestMismatchDebugger:
     """Test debugger when Mindinsight and Mindspore is mismatched."""