| @@ -56,6 +56,7 @@ class ServerStatus(enum.Enum): | |||
| RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph | |||
| WAITING = 'waiting' # the client session is ready | |||
| RUNNING = 'running' # the client session is running a script | |||
| MISMATCH = 'mismatch' # the versions of MindSpore and MindInsight do not match | |||
| @enum.unique | |||
| @@ -17,6 +17,7 @@ import copy | |||
| from functools import wraps | |||
| import mindinsight | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | |||
| Streams, RunLevel | |||
| @@ -81,6 +82,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| log.warning("No graph received before WaitCMD.") | |||
| reply = get_ack_reply(1) | |||
| return reply | |||
| # send graph if it has not been sent before | |||
| self._pre_process(request) | |||
| # deal with old command | |||
| @@ -98,6 +100,17 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def _pre_process(self, request): | |||
| """Pre-process before dealing with command.""" | |||
| # check if version is mismatch, if mismatch, send mismatch info to UI | |||
| if self._status == ServerStatus.MISMATCH: | |||
| log.warning("Versions of MindSpore and MindInsight are mismatched, " | |||
| "waiting for user to terminate the script.") | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| # put metadata into data queue | |||
| metadata = metadata_stream.get(['state', 'debugger_version']) | |||
| self._cache_store.put_data(metadata) | |||
| return | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| is_new_step = metadata_stream.step < request.cur_step | |||
| is_new_node = metadata_stream.full_name != request.cur_node | |||
| @@ -252,10 +265,11 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| EventReply, the command event. | |||
| """ | |||
| log.info("Start to wait for command.") | |||
| self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting' | |||
| self._cache_store.put_data({'metadata': {'state': 'waiting'}}) | |||
| if self._status != ServerStatus.MISMATCH: | |||
| self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value | |||
| self._cache_store.put_data({'metadata': {'state': 'waiting'}}) | |||
| event = None | |||
| while event is None and self._status == ServerStatus.WAITING: | |||
| while event is None and self._status in [ServerStatus.MISMATCH, ServerStatus.WAITING]: | |||
| log.debug("Wait for %s-th command", self._pos) | |||
| event = self._get_next_command() | |||
| return event | |||
| @@ -337,16 +351,32 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| client_ip = context.peer().split(':', 1)[-1] | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| reply = get_ack_reply() | |||
| if request.training_done: | |||
| log.info("The training from %s has finished.", client_ip) | |||
| else: | |||
| ms_version = request.ms_version | |||
| if not ms_version: | |||
| ms_version = '1.0.0' | |||
| if version_match(ms_version, mindinsight.__version__) is False: | |||
| log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s", | |||
| ms_version, mindinsight.__version__) | |||
| self._status = ServerStatus.MISMATCH | |||
| reply.version_matched = False | |||
| metadata_stream.state = 'mismatch' | |||
| else: | |||
| log.info("version is matched.") | |||
| reply.version_matched = True | |||
| metadata_stream.debugger_version = {'ms': ms_version, 'mi': mindinsight.__version__} | |||
| log.debug("Put ms_version from %s into cache.", client_ip) | |||
| metadata_stream.put(request) | |||
| metadata_stream.client_ip = client_ip | |||
| log.debug("Put new metadata from %s into cache.", client_ip) | |||
| # put metadata into data queue | |||
| metadata = metadata_stream.get() | |||
| self._cache_store.put_data(metadata) | |||
| reply = get_ack_reply() | |||
| log.debug("Send the reply to %s.", client_ip) | |||
| return reply | |||
| @@ -354,6 +384,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def SendGraph(self, request_iterator, context): | |||
| """Send graph into DebuggerCache.""" | |||
| log.info("Received graph.") | |||
| reply = get_ack_reply() | |||
| if self._status == ServerStatus.MISMATCH: | |||
| log.info("MindSpore and MindInsight versions are mismatched, waiting for user to terminate the service.") | |||
| return reply | |||
| serial_graph = b"" | |||
| for chunk in request_iterator: | |||
| serial_graph += chunk.buffer | |||
| @@ -364,14 +398,17 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals) | |||
| self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name | |||
| self._status = ServerStatus.RECEIVE_GRAPH | |||
| reply = get_ack_reply() | |||
| log.debug("Send the reply for graph.") | |||
| return reply | |||
| @debugger_wrap | |||
| def SendMultiGraphs(self, request_iterator, context): | |||
| """Send graph into DebuggerCache.""" | |||
| log.info("Received graph.") | |||
| log.info("Received multi_graphs.") | |||
| reply = get_ack_reply() | |||
| if self._status == ServerStatus.MISMATCH: | |||
| log.info("MindSpore and MindInsight versions are mismatched, waiting for user to terminate the service.") | |||
| return reply | |||
| serial_graph = b"" | |||
| graph_dict = {} | |||
| for chunk in request_iterator: | |||
| @@ -387,7 +424,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) | |||
| self._status = ServerStatus.RECEIVE_GRAPH | |||
| reply = get_ack_reply() | |||
| log.debug("Send the reply for graph.") | |||
| return reply | |||
| @@ -461,3 +497,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._received_hit = watchpoint_hits | |||
| reply = get_ack_reply() | |||
| return reply | |||
def version_match(mi_version, ms_version):
    """Check whether the MindInsight and MindSpore versions are compatible.

    Two versions are considered matched when their major and minor
    components are equal; the patch component (and anything beyond)
    is ignored.

    Args:
        mi_version (str): MindInsight version string, e.g. ``'1.1.0'``.
        ms_version (str): MindSpore version string, e.g. ``'1.1.0'``.

    Returns:
        bool, True if the major.minor parts of both versions are equal.
    """
    # Unpack exactly two components so that a malformed version string
    # (fewer than two dotted parts) fails loudly, as before.
    major_mi, minor_mi = mi_version.split('.')[:2]
    major_ms, minor_ms = ms_version.split('.')[:2]
    return (major_mi, minor_mi) == (major_ms, minor_ms)
| @@ -41,6 +41,8 @@ message Metadata { | |||
| bool training_done = 5; | |||
| // the number of total graphs | |||
| int32 graph_num = 6; | |||
| // the version number of mindspore | |||
| string ms_version = 7; | |||
| } | |||
| message Chunk { | |||
| @@ -62,6 +64,7 @@ message EventReply { | |||
| RunCMD run_cmd = 3; | |||
| SetCMD set_cmd = 4; | |||
| ViewCMD view_cmd = 5; | |||
| bool version_matched = 6; | |||
| } | |||
| } | |||
| @@ -21,7 +21,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( | |||
| package='debugger', | |||
| syntax='proto3', | |||
| serialized_options=None, | |||
| serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"~\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\xec\x01\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\xf4\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 
\x01(\x01\"\x88\x03\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3') | |||
| serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\xf4\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 
\x01(\x01\"\x88\x03\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3') | |||
| , | |||
| dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,]) | |||
| @@ -48,8 +48,8 @@ _EVENTREPLY_STATUS = _descriptor.EnumDescriptor( | |||
| ], | |||
| containing_type=None, | |||
| serialized_options=None, | |||
| serialized_start=460, | |||
| serialized_end=501, | |||
| serialized_start=508, | |||
| serialized_end=549, | |||
| ) | |||
| _sym_db.RegisterEnumDescriptor(_EVENTREPLY_STATUS) | |||
| @@ -150,8 +150,8 @@ _WATCHCONDITION_CONDITION = _descriptor.EnumDescriptor( | |||
| ], | |||
| containing_type=None, | |||
| serialized_options=None, | |||
| serialized_start=1008, | |||
| serialized_end=1400, | |||
| serialized_start=1056, | |||
| serialized_end=1448, | |||
| ) | |||
| _sym_db.RegisterEnumDescriptor(_WATCHCONDITION_CONDITION) | |||
| @@ -205,6 +205,13 @@ _METADATA = _descriptor.Descriptor( | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='ms_version', full_name='debugger.Metadata.ms_version', index=6, | |||
| number=7, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| @@ -217,8 +224,8 @@ _METADATA = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=100, | |||
| serialized_end=226, | |||
| serialized_start=101, | |||
| serialized_end=247, | |||
| ) | |||
| @@ -255,8 +262,8 @@ _CHUNK = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=228, | |||
| serialized_end=269, | |||
| serialized_start=249, | |||
| serialized_end=290, | |||
| ) | |||
| @@ -302,6 +309,13 @@ _EVENTREPLY = _descriptor.Descriptor( | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='version_matched', full_name='debugger.EventReply.version_matched', index=5, | |||
| number=6, type=8, cpp_type=7, label=1, | |||
| has_default_value=False, default_value=False, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| @@ -318,8 +332,8 @@ _EVENTREPLY = _descriptor.Descriptor( | |||
| name='cmd', full_name='debugger.EventReply.cmd', | |||
| index=0, containing_type=None, fields=[]), | |||
| ], | |||
| serialized_start=272, | |||
| serialized_end=508, | |||
| serialized_start=293, | |||
| serialized_end=556, | |||
| ) | |||
| @@ -366,8 +380,8 @@ _RUNCMD = _descriptor.Descriptor( | |||
| name='cmd', full_name='debugger.RunCMD.cmd', | |||
| index=0, containing_type=None, fields=[]), | |||
| ], | |||
| serialized_start=510, | |||
| serialized_end=586, | |||
| serialized_start=558, | |||
| serialized_end=634, | |||
| ) | |||
| @@ -418,8 +432,8 @@ _SETCMD = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=589, | |||
| serialized_end=718, | |||
| serialized_start=637, | |||
| serialized_end=766, | |||
| ) | |||
| @@ -449,8 +463,8 @@ _VIEWCMD = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=720, | |||
| serialized_end=769, | |||
| serialized_start=768, | |||
| serialized_end=817, | |||
| ) | |||
| @@ -508,8 +522,8 @@ _WATCHCONDITION_PARAMETER = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=912, | |||
| serialized_end=1005, | |||
| serialized_start=960, | |||
| serialized_end=1053, | |||
| ) | |||
| _WATCHCONDITION = _descriptor.Descriptor( | |||
| @@ -553,8 +567,8 @@ _WATCHCONDITION = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=772, | |||
| serialized_end=1400, | |||
| serialized_start=820, | |||
| serialized_end=1448, | |||
| ) | |||
| @@ -591,8 +605,8 @@ _WATCHNODE = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=1402, | |||
| serialized_end=1451, | |||
| serialized_start=1450, | |||
| serialized_end=1499, | |||
| ) | |||
| @@ -643,8 +657,8 @@ _WATCHPOINTHIT = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=1454, | |||
| serialized_end=1591, | |||
| serialized_start=1502, | |||
| serialized_end=1639, | |||
| ) | |||
| _EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS | |||
| @@ -664,6 +678,9 @@ _EVENTREPLY.fields_by_name['set_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_n | |||
| _EVENTREPLY.oneofs_by_name['cmd'].fields.append( | |||
| _EVENTREPLY.fields_by_name['view_cmd']) | |||
| _EVENTREPLY.fields_by_name['view_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] | |||
| _EVENTREPLY.oneofs_by_name['cmd'].fields.append( | |||
| _EVENTREPLY.fields_by_name['version_matched']) | |||
| _EVENTREPLY.fields_by_name['version_matched'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] | |||
| _RUNCMD.oneofs_by_name['cmd'].fields.append( | |||
| _RUNCMD.fields_by_name['run_steps']) | |||
| _RUNCMD.fields_by_name['run_steps'].containing_oneof = _RUNCMD.oneofs_by_name['cmd'] | |||
| @@ -769,8 +786,8 @@ _EVENTLISTENER = _descriptor.ServiceDescriptor( | |||
| file=DESCRIPTOR, | |||
| index=0, | |||
| serialized_options=None, | |||
| serialized_start=1594, | |||
| serialized_end=1979, | |||
| serialized_start=1642, | |||
| serialized_end=2027, | |||
| methods=[ | |||
| _descriptor.MethodDescriptor( | |||
| name='WaitCMD', | |||
| @@ -34,6 +34,7 @@ class MetadataHandler(StreamHandlerBase): | |||
| # If recommendation_confirmed is true, it only means the user has answered yes or no to the question, | |||
| # it does not necessarily mean that the user will use the recommended watch points. | |||
| self._recommendation_confirmed = False | |||
| self._debugger_version = {} | |||
| @property | |||
| def device_name(self): | |||
| @@ -135,6 +136,22 @@ class MetadataHandler(StreamHandlerBase): | |||
| """ | |||
| self._recommendation_confirmed = value | |||
| @property | |||
| def debugger_version(self): | |||
| """The property of debugger_version.""" | |||
| return self._debugger_version | |||
| @debugger_version.setter | |||
| def debugger_version(self, value): | |||
| """ | |||
| Set the property of debugger_version. | |||
| Args: | |||
| value (dict): The semantic versioning of mindinsight and mindspore, | |||
| format is {'ms': 'x.x.x', 'mi': 'x.x.x'}. | |||
| """ | |||
| self._debugger_version = value | |||
| def put(self, value): | |||
| """ | |||
| Put value into metadata cache. Called by grpc server. | |||
| @@ -170,7 +187,8 @@ class MetadataHandler(StreamHandlerBase): | |||
| 'backend': self.backend, | |||
| 'enable_recheck': self.enable_recheck, | |||
| 'graph_name': self.graph_name, | |||
| 'recommendation_confirmed': self._recommendation_confirmed | |||
| 'recommendation_confirmed': self._recommendation_confirmed, | |||
| 'debugger_version': self.debugger_version | |||
| } | |||
| else: | |||
| if not isinstance(filter_condition, list): | |||
| @@ -1 +1 @@ | |||
| {"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false}} | |||
| {"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}} | |||
| @@ -1 +1 @@ | |||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} | |||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {"ms": "1.1.0", "mi": "1.1.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} | |||
| @@ -1,45 +1 @@ | |||
| { | |||
| "metadata": { | |||
| "state": "waiting", | |||
| "step": 1, | |||
| "device_name": "0", | |||
| "node_name": "", | |||
| "backend": "Ascend", | |||
| "enable_recheck": false, | |||
| "graph_name": "", | |||
| "recommendation_confirmed": false | |||
| }, | |||
| "graph": { | |||
| "graph_names": [ | |||
| "graph_0", | |||
| "graph_1" | |||
| ], | |||
| "nodes": [ | |||
| { | |||
| "name": "graph_0", | |||
| "type": "name_scope", | |||
| "attr": {}, | |||
| "input": {}, | |||
| "output": {}, | |||
| "output_i": 0, | |||
| "proxy_input": {}, | |||
| "proxy_output": {}, | |||
| "subnode_count": 2, | |||
| "independent_layout": false | |||
| }, | |||
| { | |||
| "name": "graph_1", | |||
| "type": "name_scope", | |||
| "attr": {}, | |||
| "input": {}, | |||
| "output": {}, | |||
| "output_i": 0, | |||
| "proxy_input": {}, | |||
| "proxy_output": {}, | |||
| "subnode_count": 2, | |||
| "independent_layout": false | |||
| } | |||
| ] | |||
| }, | |||
| "watch_points": [] | |||
| } | |||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.1.0", "mi": "1.1.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} | |||
| @@ -0,0 +1 @@ | |||
| {"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}} | |||
| @@ -28,7 +28,7 @@ from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE | |||
| class MockDebuggerClient: | |||
| """Mocked Debugger client.""" | |||
| def __init__(self, hostname='localhost:50051', backend='Ascend', graph_num=1): | |||
| def __init__(self, hostname='localhost:50051', backend='Ascend', graph_num=1, ms_version='1.1.0'): | |||
| channel = grpc.insecure_channel(hostname) | |||
| self.stub = EventListenerStub(channel) | |||
| self.flag = True | |||
| @@ -38,6 +38,7 @@ class MockDebuggerClient: | |||
| self._cur_node = '' | |||
| self._backend = backend | |||
| self._graph_num = graph_num | |||
| self._ms_version = ms_version | |||
| def _clean(self): | |||
| """Clean cache.""" | |||
| @@ -113,6 +114,7 @@ class MockDebuggerClient: | |||
| metadata.cur_node = self._cur_node | |||
| metadata.backend = self._backend | |||
| metadata.training_done = training_done | |||
| metadata.ms_version = self._ms_version | |||
| return metadata | |||
| def send_metadata_cmd(self, training_done=False): | |||
| @@ -121,6 +123,8 @@ class MockDebuggerClient: | |||
| metadata = self.get_metadata_cmd(training_done) | |||
| response = self.stub.SendMetadata(metadata) | |||
| assert response.status == EventReply.Status.OK | |||
| if response.version_matched is False: | |||
| self.command_loop() | |||
| if training_done is False: | |||
| self.send_graph_cmd() | |||
| print("finish") | |||
| @@ -519,7 +519,7 @@ class TestGPUDebugger: | |||
| class TestMultiGraphDebugger: | |||
| """Test debugger on Ascend backend.""" | |||
| """Test debugger on Ascend backend for multi_graph.""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| @@ -673,3 +673,27 @@ def create_watchpoint_and_wait(app_client): | |||
| assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} | |||
| # wait for server has received watchpoint hit | |||
| check_waiting_state(app_client) | |||
| class TestMismatchDebugger: | |||
| """Test debugger when MindInsight and MindSpore versions are mismatched.""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Setup class.""" | |||
| cls._debugger_client = MockDebuggerClient(backend='Ascend', ms_version='1.0.0') | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.parametrize("body_data, expect_file", [ | |||
| ({'mode': 'all'}, 'version_mismatch.json') | |||
| ]) | |||
| def test_retrieve_when_version_mismatch(self, app_client, body_data, expect_file): | |||
| """Test retrieve when the versions are mismatched.""" | |||
| url = 'retrieve' | |||
| with self._debugger_client.get_thread_instance(): | |||
| send_and_compare_result(app_client, url, body_data, expect_file) | |||
| send_terminate_cmd(app_client) | |||
| @@ -1 +1 @@ | |||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false}, "graph": {}, "watch_points": []} | |||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "watch_points": []} | |||
| @@ -211,8 +211,17 @@ class TestDebuggerGrpcServer: | |||
| def test_send_matadata(self): | |||
| """Test SendMetadata interface.""" | |||
| res = self._server.SendMetadata(MagicMock(training_done=False), MagicMock()) | |||
| assert res == get_ack_reply() | |||
| res = self._server.SendMetadata(MagicMock(training_done=False, ms_version='1.1.0'), MagicMock()) | |||
| expect_reply = get_ack_reply() | |||
| expect_reply.version_matched = True | |||
| assert res == expect_reply | |||
| def test_send_matadata_with_mismatched(self): | |||
| """Test SendMetadata interface with mismatched versions.""" | |||
| res = self._server.SendMetadata(MagicMock(training_done=False, ms_version='1.0.0'), MagicMock()) | |||
| expect_reply = get_ack_reply() | |||
| expect_reply.version_matched = False | |||
| assert res == expect_reply | |||
| def test_send_matadata_with_training_done(self): | |||
| """Test SendMetadata interface with training done.""" | |||