You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the utils."""
  16. import enum
  17. import numpy as np
  18. from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
  19. from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
  20. # translate the MindSpore type to numpy type.
  21. NUMPY_TYPE_MAP = {
  22. 'DT_BOOL': np.bool,
  23. 'DT_INT8': np.int8,
  24. 'DT_INT16': np.int16,
  25. 'DT_INT32': np.int32,
  26. 'DT_INT64': np.int64,
  27. 'DT_UINT8': np.uint8,
  28. 'DT_UINT16': np.uint16,
  29. 'DT_UINT32': np.uint32,
  30. 'DT_UINT64': np.uint64,
  31. 'DT_FLOAT16': np.float16,
  32. 'DT_FLOAT32': np.float32,
  33. 'DT_FLOAT64': np.float64,
  34. 'DT_STRING': np.str,
  35. 'DT_TYPE': np.str
  36. }
  37. MS_VERSION = '1.0.x'
  38. @enum.unique
  39. class ReplyStates(enum.Enum):
  40. """Define the status of reply."""
  41. SUCCESS = 0
  42. FAILED = -1
  43. @enum.unique
  44. class ServerStatus(enum.Enum):
  45. """The status of debugger server."""
  46. PENDING = 'pending' # no client session has been connected
  47. RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph
  48. WAITING = 'waiting' # the client session is ready
  49. RUNNING = 'running' # the client session is running a script
  50. MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched
  51. SENDING = 'sending' # the request is in cache but not be sent to client
  52. @enum.unique
  53. class Streams(enum.Enum):
  54. """Define the enable streams to be deal with."""
  55. COMMAND = "command"
  56. DATA = "data"
  57. METADATA = "metadata"
  58. GRAPH = 'node'
  59. TENSOR = 'tensor'
  60. WATCHPOINT = 'watchpoint'
  61. WATCHPOINT_HIT = 'watchpoint_hit'
  62. DEVICE = 'device'
  63. class RunLevel(enum.Enum):
  64. """Run Level enum, it depends on whether the program is executed node by node,
  65. step by step, or in recheck phase"""
  66. NODE = "node"
  67. STEP = "step"
  68. RECHECK = "recheck"
  69. def get_ack_reply(state=0):
  70. """The the ack EventReply."""
  71. reply = EventReply()
  72. state_mapping = {
  73. 0: EventReply.Status.OK,
  74. 1: EventReply.Status.FAILED,
  75. 2: EventReply.Status.PENDING
  76. }
  77. reply.status = state_mapping[state]
  78. return reply
  79. def wrap_reply_response(error_code=None, error_message=None):
  80. """
  81. Wrap reply response.
  82. Args:
  83. error_code (str): Error code. Default: None.
  84. error_message (str): Error message. Default: None.
  85. Returns:
  86. str, serialized response.
  87. """
  88. if error_code is None:
  89. reply = {'state': ReplyStates.SUCCESS.value}
  90. else:
  91. reply = {
  92. 'state': ReplyStates.FAILED.value,
  93. 'error_code': error_code,
  94. 'error_message': error_message
  95. }
  96. return reply
  97. def create_view_event_from_tensor_basic_info(tensors_info):
  98. """
  99. Create view event reply according to tensor names.
  100. Args:
  101. tensors_info (list[TensorBasicInfo]): The list of TensorBasicInfo. Each element has keys:
  102. `full_name`, `node_type`, `iter`.
  103. Returns:
  104. EventReply, the event reply with view cmd.
  105. """
  106. view_event = get_ack_reply()
  107. for tensor_info in tensors_info:
  108. node_type = tensor_info.node_type
  109. if node_type == NodeTypeEnum.CONST.value:
  110. continue
  111. truncate_tag = node_type == NodeTypeEnum.PARAMETER.value
  112. tensor_name = tensor_info.full_name
  113. # create view command
  114. ms_tensor = view_event.view_cmd.tensors.add()
  115. ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1)
  116. ms_tensor.truncate = truncate_tag
  117. ms_tensor.iter = tensor_info.iter
  118. return view_event
  119. def is_scope_type(node_type):
  120. """Judge whether the type is scope type."""
  121. return node_type.endswith('scope')
  122. def is_cst_type(node_type):
  123. """Judge whether the type is const type."""
  124. return node_type == NodeTypeEnum.CONST.value
  125. def version_match(ms_version, mi_version):
  126. """Judge if the version of Mindinsight and Mindspore is matched."""
  127. if not ms_version:
  128. ms_version = MS_VERSION
  129. mi_major, mi_minor = mi_version.split('.')[:2]
  130. ms_major, ms_minor = ms_version.split('.')[:2]
  131. return mi_major == ms_major and mi_minor == ms_minor
  132. @enum.unique
  133. class DebuggerServerMode(enum.Enum):
  134. """Debugger Server Mode."""
  135. ONLINE = 'online'
  136. OFFLINE = 'offline'
  137. class DumpSettings(enum.Enum):
  138. """Dump settings."""
  139. E2E_DUMP_SETTINGS = 'e2e_dump_settings'
  140. COMMON_DUMP_SETTINGS = 'common_dump_settings'
  141. ASYNC_DUMP_SETTINGS = 'async_dump_settings'