Browse Source

Add version check for MindInsight and MindSpore

tags/v1.1.0
maning202007 5 years ago
parent
commit
357602a12a
15 changed files with 165 additions and 1183 deletions
  1. +1
    -0
      mindinsight/debugger/common/utils.py
  2. +50
    -7
      mindinsight/debugger/debugger_grpc_server.py
  3. +3
    -0
      mindinsight/debugger/proto/debug_grpc.proto
  4. +44
    -27
      mindinsight/debugger/proto/debug_grpc_pb2.py
  5. +19
    -1
      mindinsight/debugger/stream_handler/metadata_handler.py
  6. +1
    -1
      tests/st/func/debugger/expect_results/restful_results/before_train_begin.json
  7. +1
    -1
      tests/st/func/debugger/expect_results/restful_results/multi_next_node.json
  8. +1
    -45
      tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json
  9. +1
    -548
      tests/st/func/debugger/expect_results/restful_results/retrieve_all.json
  10. +1
    -548
      tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json
  11. +1
    -0
      tests/st/func/debugger/expect_results/restful_results/version_mismatch.json
  12. +5
    -1
      tests/st/func/debugger/mock_ms_client.py
  13. +25
    -1
      tests/st/func/debugger/test_restful_api.py
  14. +1
    -1
      tests/ut/debugger/expected_results/debugger_server/retrieve_all.json
  15. +11
    -2
      tests/ut/debugger/test_debugger_grpc_server.py

+ 1
- 0
mindinsight/debugger/common/utils.py View File

@@ -56,6 +56,7 @@ class ServerStatus(enum.Enum):
RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph
WAITING = 'waiting' # the client session is ready
RUNNING = 'running' # the client session is running a script
MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched


@enum.unique


+ 50
- 7
mindinsight/debugger/debugger_grpc_server.py View File

@@ -17,6 +17,7 @@ import copy

from functools import wraps

import mindinsight
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
Streams, RunLevel
@@ -81,6 +82,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.warning("No graph received before WaitCMD.")
reply = get_ack_reply(1)
return reply

# send graph if it has not been sent before
self._pre_process(request)
# deal with old command
@@ -98,6 +100,17 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

def _pre_process(self, request):
"""Pre-process before dealing with command."""

# check if version is mismatch, if mismatch, send mismatch info to UI
if self._status == ServerStatus.MISMATCH:
log.warning("Version of Mindspore and Mindinsight re unmatched,"
"waiting for user to terminate the script.")
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
# put metadata into data queue
metadata = metadata_stream.get(['state', 'debugger_version'])
self._cache_store.put_data(metadata)
return

metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
is_new_step = metadata_stream.step < request.cur_step
is_new_node = metadata_stream.full_name != request.cur_node
@@ -252,10 +265,11 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
EventReply, the command event.
"""
log.info("Start to wait for command.")
self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
self._cache_store.put_data({'metadata': {'state': 'waiting'}})
if self._status != ServerStatus.MISMATCH:
self._cache_store.get_stream_handler(Streams.METADATA).state = ServerStatus.WAITING.value
self._cache_store.put_data({'metadata': {'state': 'waiting'}})
event = None
while event is None and self._status == ServerStatus.WAITING:
while event is None and self._status in [ServerStatus.MISMATCH, ServerStatus.WAITING]:
log.debug("Wait for %s-th command", self._pos)
event = self._get_next_command()
return event
@@ -337,16 +351,32 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

client_ip = context.peer().split(':', 1)[-1]
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
reply = get_ack_reply()
if request.training_done:
log.info("The training from %s has finished.", client_ip)
else:
ms_version = request.ms_version
if not ms_version:
ms_version = '1.0.0'
if version_match(ms_version, mindinsight.__version__) is False:
log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s",
ms_version, mindinsight.__version__)
self._status = ServerStatus.MISMATCH
reply.version_matched = False
metadata_stream.state = 'mismatch'
else:
log.info("version is matched.")
reply.version_matched = True

metadata_stream.debugger_version = {'ms': ms_version, 'mi': mindinsight.__version__}
log.debug("Put ms_version from %s into cache.", client_ip)

metadata_stream.put(request)
metadata_stream.client_ip = client_ip
log.debug("Put new metadata from %s into cache.", client_ip)
# put metadata into data queue
metadata = metadata_stream.get()
self._cache_store.put_data(metadata)
reply = get_ack_reply()
log.debug("Send the reply to %s.", client_ip)
return reply

@@ -354,6 +384,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
def SendGraph(self, request_iterator, context):
"""Send graph into DebuggerCache."""
log.info("Received graph.")
reply = get_ack_reply()
if self._status == ServerStatus.MISMATCH:
log.info("Mindspore and Mindinsight is unmatched, waiting for user to terminate the service.")
return reply
serial_graph = b""
for chunk in request_iterator:
serial_graph += chunk.buffer
@@ -364,14 +398,17 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
self._status = ServerStatus.RECEIVE_GRAPH
reply = get_ack_reply()
log.debug("Send the reply for graph.")
return reply

@debugger_wrap
def SendMultiGraphs(self, request_iterator, context):
"""Send graph into DebuggerCache."""
log.info("Received graph.")
log.info("Received multi_graphs.")
reply = get_ack_reply()
if self._status == ServerStatus.MISMATCH:
log.info("Mindspore and Mindinsight is unmatched, waiting for user to terminate the service.")
return reply
serial_graph = b""
graph_dict = {}
for chunk in request_iterator:
@@ -387,7 +424,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):

self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
self._status = ServerStatus.RECEIVE_GRAPH
reply = get_ack_reply()
log.debug("Send the reply for graph.")
return reply

@@ -461,3 +497,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._received_hit = watchpoint_hits
reply = get_ack_reply()
return reply


def version_match(mi_version, ms_version):
    """
    Judge whether the versions of MindInsight and MindSpore match.

    Only the major and minor components of the semantic version are
    compared; the patch component is ignored.

    NOTE(review): the call site passes (ms_version, mindinsight.__version__),
    i.e. the arguments swapped relative to this signature. The comparison is
    symmetric so the result is unaffected, but the naming should be aligned.

    Args:
        mi_version (str): MindInsight version, e.g. '1.1.0'.
        ms_version (str): MindSpore version, e.g. '1.1.0'.

    Returns:
        bool, True if major and minor versions are equal, False otherwise.
    """
    # Compare the leading components as lists instead of unpacking two names,
    # so a malformed/short version string (e.g. '1') yields False rather than
    # raising ValueError.
    return mi_version.split('.')[:2] == ms_version.split('.')[:2]

+ 3
- 0
mindinsight/debugger/proto/debug_grpc.proto View File

@@ -41,6 +41,8 @@ message Metadata {
bool training_done = 5;
// the number of total graphs
int32 graph_num = 6;
// the version number of mindspore
string ms_version = 7;
}

message Chunk {
@@ -62,6 +64,7 @@ message EventReply {
RunCMD run_cmd = 3;
SetCMD set_cmd = 4;
ViewCMD view_cmd = 5;
bool version_matched = 6;
}
}



+ 44
- 27
mindinsight/debugger/proto/debug_grpc_pb2.py View File

@@ -21,7 +21,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='debugger',
syntax='proto3',
serialized_options=None,
serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"~\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\xec\x01\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\xf4\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 
\x01(\x01\"\x88\x03\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3')
serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\xf4\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 
\x01(\x01\"\x88\x03\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3')
,
dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,])

@@ -48,8 +48,8 @@ _EVENTREPLY_STATUS = _descriptor.EnumDescriptor(
],
containing_type=None,
serialized_options=None,
serialized_start=460,
serialized_end=501,
serialized_start=508,
serialized_end=549,
)
_sym_db.RegisterEnumDescriptor(_EVENTREPLY_STATUS)

@@ -150,8 +150,8 @@ _WATCHCONDITION_CONDITION = _descriptor.EnumDescriptor(
],
containing_type=None,
serialized_options=None,
serialized_start=1008,
serialized_end=1400,
serialized_start=1056,
serialized_end=1448,
)
_sym_db.RegisterEnumDescriptor(_WATCHCONDITION_CONDITION)

@@ -205,6 +205,13 @@ _METADATA = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='ms_version', full_name='debugger.Metadata.ms_version', index=6,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -217,8 +224,8 @@ _METADATA = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=100,
serialized_end=226,
serialized_start=101,
serialized_end=247,
)


@@ -255,8 +262,8 @@ _CHUNK = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=228,
serialized_end=269,
serialized_start=249,
serialized_end=290,
)


@@ -302,6 +309,13 @@ _EVENTREPLY = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='version_matched', full_name='debugger.EventReply.version_matched', index=5,
number=6, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -318,8 +332,8 @@ _EVENTREPLY = _descriptor.Descriptor(
name='cmd', full_name='debugger.EventReply.cmd',
index=0, containing_type=None, fields=[]),
],
serialized_start=272,
serialized_end=508,
serialized_start=293,
serialized_end=556,
)


@@ -366,8 +380,8 @@ _RUNCMD = _descriptor.Descriptor(
name='cmd', full_name='debugger.RunCMD.cmd',
index=0, containing_type=None, fields=[]),
],
serialized_start=510,
serialized_end=586,
serialized_start=558,
serialized_end=634,
)


@@ -418,8 +432,8 @@ _SETCMD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=589,
serialized_end=718,
serialized_start=637,
serialized_end=766,
)


@@ -449,8 +463,8 @@ _VIEWCMD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=720,
serialized_end=769,
serialized_start=768,
serialized_end=817,
)


@@ -508,8 +522,8 @@ _WATCHCONDITION_PARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=912,
serialized_end=1005,
serialized_start=960,
serialized_end=1053,
)

_WATCHCONDITION = _descriptor.Descriptor(
@@ -553,8 +567,8 @@ _WATCHCONDITION = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=772,
serialized_end=1400,
serialized_start=820,
serialized_end=1448,
)


@@ -591,8 +605,8 @@ _WATCHNODE = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=1402,
serialized_end=1451,
serialized_start=1450,
serialized_end=1499,
)


@@ -643,8 +657,8 @@ _WATCHPOINTHIT = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=1454,
serialized_end=1591,
serialized_start=1502,
serialized_end=1639,
)

_EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS
@@ -664,6 +678,9 @@ _EVENTREPLY.fields_by_name['set_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_n
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
_EVENTREPLY.fields_by_name['view_cmd'])
_EVENTREPLY.fields_by_name['view_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
_EVENTREPLY.fields_by_name['version_matched'])
_EVENTREPLY.fields_by_name['version_matched'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_RUNCMD.oneofs_by_name['cmd'].fields.append(
_RUNCMD.fields_by_name['run_steps'])
_RUNCMD.fields_by_name['run_steps'].containing_oneof = _RUNCMD.oneofs_by_name['cmd']
@@ -769,8 +786,8 @@ _EVENTLISTENER = _descriptor.ServiceDescriptor(
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=1594,
serialized_end=1979,
serialized_start=1642,
serialized_end=2027,
methods=[
_descriptor.MethodDescriptor(
name='WaitCMD',


+ 19
- 1
mindinsight/debugger/stream_handler/metadata_handler.py View File

@@ -34,6 +34,7 @@ class MetadataHandler(StreamHandlerBase):
# If recommendation_confirmed is true, it only means the user has answered yes or no to the question,
# it does not necessarily mean that the user will use the recommended watch points.
self._recommendation_confirmed = False
self._debugger_version = {}

@property
def device_name(self):
@@ -135,6 +136,22 @@ class MetadataHandler(StreamHandlerBase):
"""
self._recommendation_confirmed = value

@property
def debugger_version(self):
    """dict: versions of MindSpore ('ms') and MindInsight ('mi'); initialized to {} and filled when metadata is received."""
    return self._debugger_version

@debugger_version.setter
def debugger_version(self, value):
    """
    Set the property of debugger_version.

    Args:
        value (dict): The semantic versioning of mindinsight and mindspore,
            format is {'ms': 'x.x.x', 'mi': 'x.x.x'}.
    """
    self._debugger_version = value

def put(self, value):
"""
Put value into metadata cache. Called by grpc server.
@@ -170,7 +187,8 @@ class MetadataHandler(StreamHandlerBase):
'backend': self.backend,
'enable_recheck': self.enable_recheck,
'graph_name': self.graph_name,
'recommendation_confirmed': self._recommendation_confirmed
'recommendation_confirmed': self._recommendation_confirmed,
'debugger_version': self.debugger_version
}
else:
if not isinstance(filter_condition, list):


+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/before_train_begin.json View File

@@ -1 +1 @@
{"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false}}
{"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}}

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/multi_next_node.json View File

@@ -1 +1 @@
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {"ms": "1.1.0", "mi": "1.1.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}

+ 1
- 45
tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json View File

@@ -1,45 +1 @@
{
"metadata": {
"state": "waiting",
"step": 1,
"device_name": "0",
"node_name": "",
"backend": "Ascend",
"enable_recheck": false,
"graph_name": "",
"recommendation_confirmed": false
},
"graph": {
"graph_names": [
"graph_0",
"graph_1"
],
"nodes": [
{
"name": "graph_0",
"type": "name_scope",
"attr": {},
"input": {},
"output": {},
"output_i": 0,
"proxy_input": {},
"proxy_output": {},
"subnode_count": 2,
"independent_layout": false
},
{
"name": "graph_1",
"type": "name_scope",
"attr": {},
"input": {},
"output": {},
"output_i": 0,
"proxy_input": {},
"proxy_output": {},
"subnode_count": 2,
"independent_layout": false
}
]
},
"watch_points": []
}
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.1.0", "mi": "1.1.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}

+ 1
- 548
tests/st/func/debugger/expect_results/restful_results/retrieve_all.json
File diff suppressed because it is too large
View File


+ 1
- 548
tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json
File diff suppressed because it is too large
View File


+ 1
- 0
tests/st/func/debugger/expect_results/restful_results/version_mismatch.json View File

@@ -0,0 +1 @@
{"metadata": {"state": "pending", "step": 0, "device_name": "", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}}

+ 5
- 1
tests/st/func/debugger/mock_ms_client.py View File

@@ -28,7 +28,7 @@ from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE
class MockDebuggerClient:
"""Mocked Debugger client."""

def __init__(self, hostname='localhost:50051', backend='Ascend', graph_num=1):
def __init__(self, hostname='localhost:50051', backend='Ascend', graph_num=1, ms_version='1.1.0'):
channel = grpc.insecure_channel(hostname)
self.stub = EventListenerStub(channel)
self.flag = True
@@ -38,6 +38,7 @@ class MockDebuggerClient:
self._cur_node = ''
self._backend = backend
self._graph_num = graph_num
self._ms_version = ms_version

def _clean(self):
"""Clean cache."""
@@ -113,6 +114,7 @@ class MockDebuggerClient:
metadata.cur_node = self._cur_node
metadata.backend = self._backend
metadata.training_done = training_done
metadata.ms_version = self._ms_version
return metadata

def send_metadata_cmd(self, training_done=False):
@@ -121,6 +123,8 @@ class MockDebuggerClient:
metadata = self.get_metadata_cmd(training_done)
response = self.stub.SendMetadata(metadata)
assert response.status == EventReply.Status.OK
if response.version_matched is False:
self.command_loop()
if training_done is False:
self.send_graph_cmd()
print("finish")


+ 25
- 1
tests/st/func/debugger/test_restful_api.py View File

@@ -519,7 +519,7 @@ class TestGPUDebugger:


class TestMultiGraphDebugger:
"""Test debugger on Ascend backend."""
"""Test debugger on Ascend backend for multi_graph."""

@classmethod
def setup_class(cls):
@@ -673,3 +673,27 @@ def create_watchpoint_and_wait(app_client):
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
# wait for server has received watchpoint hit
check_waiting_state(app_client)

class TestMismatchDebugger:
    """Test debugger when Mindinsight and Mindspore is mismatched."""

    @classmethod
    def setup_class(cls):
        """Setup class: create a mock client reporting an older MindSpore version."""
        # ms_version='1.0.0' is meant to differ from the MindInsight release so the
        # server enters the version-mismatch state — TODO confirm against mindinsight.__version__.
        cls._debugger_client = MockDebuggerClient(backend='Ascend', ms_version='1.0.0')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'version_mismatch.json')
    ])
    def test_retrieve_when_version_mismatch(self, app_client, body_data, expect_file):
        """Test retrieve result when the MindSpore and MindInsight versions mismatch."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

+ 1
- 1
tests/ut/debugger/expected_results/debugger_server/retrieve_all.json View File

@@ -1 +1 @@
{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false}, "graph": {}, "watch_points": []}
{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "watch_points": []}

+ 11
- 2
tests/ut/debugger/test_debugger_grpc_server.py View File

@@ -211,8 +211,17 @@ class TestDebuggerGrpcServer:

def test_send_matadata(self):
    """Test SendMetadata interface when the reported MindSpore version matches."""
    res = self._server.SendMetadata(MagicMock(training_done=False, ms_version='1.1.0'), MagicMock())
    expect_reply = get_ack_reply()
    # A matched version is acknowledged with version_matched=True in the reply.
    expect_reply.version_matched = True
    assert res == expect_reply

def test_send_matadata_with_mismatched(self):
    """Test SendMetadata interface when the reported MindSpore version mismatches."""
    res = self._server.SendMetadata(MagicMock(training_done=False, ms_version='1.0.0'), MagicMock())
    expect_reply = get_ack_reply()
    # A mismatched version is reported back with version_matched=False.
    expect_reply.version_matched = False
    assert res == expect_reply

def test_send_matadata_with_training_done(self):
"""Test SendMatadata interface."""


Loading…
Cancel
Save