@@ -68,9 +68,10 @@ class PlatformEnum(Enum):

class TargetTypeEnum(Enum):
    """Target types."""
    TENSOR = 'tensor'
    ACTIVATION = 'activation'
    GRADIENT = 'gradient'
    PARAMETER = 'parameter'
    WEIGHT = 'weight'

class ParamTypeEnum(Enum):
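For orientation, the new `PARAMETER` member is consumed as a plain string filter by the gRPC server change below; a minimal sketch, assuming only the enum members shown above:

```python
from enum import Enum

class TargetTypeEnum(Enum):
    """Target types (mirror of the enum above)."""
    TENSOR = 'tensor'
    ACTIVATION = 'activation'
    GRADIENT = 'gradient'
    PARAMETER = 'parameter'
    WEIGHT = 'weight'

# The graph search in the server change below receives the raw string value:
pattern = {'node_category': TargetTypeEnum.PARAMETER.value}
assert pattern == {'node_category': 'parameter'}
```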
@@ -21,6 +21,7 @@ import mindinsight

from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    Streams, RunLevel
from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto

@@ -117,9 +118,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

        # clean cache data at the beginning of a new step or when the node has changed
        if is_new_step or is_new_node:
            self._cache_store.clean_data()
        if is_new_step:
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
        # receive graph at the beginning of the training
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._send_graph_flag(metadata_stream)
@@ -397,6 +398,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
        self._record_parameter_names()
        self._status = ServerStatus.RECEIVE_GRAPH
        log.debug("Send the reply for graph.")
        return reply
@@ -423,10 +425,20 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

                sub_graph.const_vals)
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
        self._record_parameter_names()
        self._status = ServerStatus.RECEIVE_GRAPH
        log.debug("Send the reply for graph.")
        return reply

    def _record_parameter_names(self):
        """Record parameter full names in tensor handler."""
        parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_searched_nodes(
            pattern={'node_category': TargetTypeEnum.PARAMETER.value})
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        for nodes in parameter_nodes.values():
            tensor_names = [node.full_name + ':0' for node in nodes]
            tensor_stream.record_parameter_names(tensor_names)
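The helper maps each parameter node to a tensor name by appending an output-slot suffix; a minimal sketch of that mapping, with a hypothetical `FakeNode` standing in for the real graph node objects (the node names are made up):

```python
from collections import namedtuple

# Hypothetical stand-in for the graph node objects returned by
# get_searched_nodes(); only the full_name attribute is used here.
FakeNode = namedtuple('FakeNode', ['full_name'])

nodes = [FakeNode('Default/network/fc1.weight'), FakeNode('Default/network/fc1.bias')]
# ':0' addresses the first output slot of a parameter node.
tensor_names = [node.full_name + ':0' for node in nodes]
assert tensor_names == ['Default/network/fc1.weight:0', 'Default/network/fc1.bias:0']
```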
    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""

@@ -66,6 +66,19 @@ class NodeTypeIdentifier:

        return self.identify_func(*args, **kwargs)
def is_parameter_node(node):
    """
    Check if the node is parameter type.

    Args:
        node (Node): The node object.

    Returns:
        bool, True if the node is a parameter node.
    """
    return node.type == NodeTypeEnum.PARAMETER.value
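A quick, self-contained check of the predicate; the `NodeTypeEnum` value string here is an assumption and `SimpleNamespace` is only a stub for the real node class:

```python
from enum import Enum
from types import SimpleNamespace

class NodeTypeEnum(Enum):
    """Assumed subset of the real enum; the exact value string may differ."""
    PARAMETER = 'Parameter'

def is_parameter_node(node):
    """Check if the node is parameter type."""
    return node.type == NodeTypeEnum.PARAMETER.value

# SimpleNamespace stands in for a real graph Node object.
assert is_parameter_node(SimpleNamespace(type='Parameter'))
assert not is_parameter_node(SimpleNamespace(type='Const'))
```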
def is_weight_node(node):
    """
    Check if the node is weight type.

@@ -27,11 +27,17 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison

TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])

class TensorHandler(StreamHandlerBase):
| """Metadata Handler.""" | |||
    def __init__(self):
        # the collection of parameter full names
        self._param_names = set()
        # const value objects, the format is like: dict[<const name>, <OpTensor object>]
        self._const_vals = {}
        # tensor values, the format is like:
        # dict[<tensor full name>, dict[<step_num>, <OpTensor object>]]
        self._tensors = {}
        self._cur_step = 0
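The nested `_tensors` mapping is the structure all of the cleaning logic below operates on; a small sketch with string placeholders standing in for the real `OpTensor` objects (the tensor names are illustrative):

```python
# Shape of the tensor cache; strings stand in for OpTensor objects.
tensors = {
    'Default/network/fc1.weight:0': {0: '<OpTensor @ step 0>', 1: '<OpTensor @ step 1>'},
    'Default/network/relu:0': {1: '<OpTensor @ step 1>'},
}
# One cached value is addressed by (full tensor name, step number):
value = tensors['Default/network/relu:0'].get(1)
assert value == '<OpTensor @ step 1>'
```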
@@ -147,6 +153,19 @@ class TensorHandler(StreamHandlerBase):

            const_tensor = ConstTensor(const_val)
            self._const_vals[const_tensor.name] = const_tensor
    def record_parameter_names(self, names):
        """
        Record parameter names.

        Note:
            Parameter values may change within an iteration step, so the cached
            parameter values must be cleaned after each node run.

        Args:
            names (list[str]): List of tensor full names.
        """
        self._param_names.update(names)
        log.debug("Record %d parameters in cache. Total parameter number: %d",
                  len(names), len(self._param_names))
    def get(self, filter_condition=None):
        """
        Get full tensor value.

@@ -293,7 +312,13 @@ class TensorHandler(StreamHandlerBase):
    def clean_tensors(self, cur_step):
        """Clean the tensor cache."""
        if cur_step != self._cur_step:
            self._cur_step = cur_step
            self._clean_expired_tensors(cur_step)
        self._clean_parameters()
    def _clean_expired_tensors(self, cur_step):
        """Clean cached tensor values from steps earlier than the previous step."""
        expired_tensor = []
        for tensor_name, tensor in self._tensors.items():
            expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
@@ -304,6 +329,13 @@ class TensorHandler(StreamHandlerBase):

        for tensor_name in expired_tensor:
            self._tensors.pop(tensor_name)
    def _clean_parameters(self):
        """Clean parameter cache."""
        for param in self._param_names:
            if param in self._tensors:
                self._tensors.pop(param)
                log.debug("Clean param %s in cache.", param)
    def get_tensors_diff(self, tensor_name, shape, tolerance=0):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

@@ -123,11 +123,10 @@ class MockDebuggerClient:

        metadata = self.get_metadata_cmd(training_done)
        response = self.stub.SendMetadata(metadata)
        assert response.status == EventReply.Status.OK
        if response.HasField('version_matched') and response.version_matched is False:
            self.command_loop()
        if training_done is False:
            self.send_graph_cmd()
| print("finish") | |||
    def send_graph_cmd(self):
        """Send graph to debugger server."""

@@ -672,6 +672,7 @@ def create_watchpoint_and_wait(app_client):

    # wait for the server to receive the watchpoint hit
    check_state(app_client)

class TestMismatchDebugger:
    """Test debugger behavior when the MindInsight and MindSpore versions mismatch."""