You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the utils."""
  16. import enum
  17. import numpy as np
  18. from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
  19. from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
  20. # translate the MindSpore type to numpy type.
  21. NUMPY_TYPE_MAP = {
  22. 'DT_BOOL': np.bool,
  23. 'DT_INT8': np.int8,
  24. 'DT_INT16': np.int16,
  25. 'DT_INT32': np.int32,
  26. 'DT_INT64': np.int64,
  27. 'DT_UINT8': np.uint8,
  28. 'DT_UINT16': np.uint16,
  29. 'DT_UINT32': np.uint32,
  30. 'DT_UINT64': np.uint64,
  31. 'DT_FLOAT16': np.float16,
  32. 'DT_FLOAT32': np.float32,
  33. 'DT_FLOAT64': np.float64,
  34. 'DT_STRING': np.str
  35. }
  36. @enum.unique
  37. class ReplyStates(enum.Enum):
  38. """Define the status of reply."""
  39. SUCCESS = 0
  40. FAILED = -1
  41. @enum.unique
  42. class ServerStatus(enum.Enum):
  43. """The status of debugger server."""
  44. PENDING = 'pending' # no client session has been connected
  45. RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph
  46. WAITING = 'waiting' # the client session is ready
  47. RUNNING = 'running' # the client session is running a script
  48. MISMATCH = 'mismatch' # the version of Mindspore and Mindinsight is not matched
  49. SENDING = 'sending' # the request is in cache but not be sent to client
  50. @enum.unique
  51. class Streams(enum.Enum):
  52. """Define the enable streams to be deal with."""
  53. COMMAND = "command"
  54. DATA = "data"
  55. METADATA = "metadata"
  56. GRAPH = 'node'
  57. TENSOR = 'tensor'
  58. WATCHPOINT = 'watchpoint'
  59. WATCHPOINT_HIT = 'watchpoint_hit'
  60. class RunLevel(enum.Enum):
  61. """Run Level enum, it depends on whether the program is executed node by node,
  62. step by step, or in recheck phase"""
  63. NODE = "node"
  64. STEP = "step"
  65. RECHECK = "recheck"
  66. def get_ack_reply(state=0):
  67. """The the ack EventReply."""
  68. reply = EventReply()
  69. state_mapping = {
  70. 0: EventReply.Status.OK,
  71. 1: EventReply.Status.FAILED,
  72. 2: EventReply.Status.PENDING
  73. }
  74. reply.status = state_mapping[state]
  75. return reply
  76. def wrap_reply_response(error_code=None, error_message=None):
  77. """
  78. Wrap reply response.
  79. Args:
  80. error_code (str): Error code. Default: None.
  81. error_message (str): Error message. Default: None.
  82. Returns:
  83. str, serialized response.
  84. """
  85. if error_code is None:
  86. reply = {'state': ReplyStates.SUCCESS.value}
  87. else:
  88. reply = {
  89. 'state': ReplyStates.FAILED.value,
  90. 'error_code': error_code,
  91. 'error_message': error_message
  92. }
  93. return reply
  94. def create_view_event_from_tensor_basic_info(tensors_info):
  95. """
  96. Create view event reply according to tensor names.
  97. Args:
  98. tensors_info (list[TensorBasicInfo]): The list of TensorBasicInfo. Each element has keys:
  99. `full_name`, `node_type`, `iter`.
  100. Returns:
  101. EventReply, the event reply with view cmd.
  102. """
  103. view_event = get_ack_reply()
  104. for tensor_info in tensors_info:
  105. node_type = tensor_info.node_type
  106. if node_type == NodeTypeEnum.CONST.value:
  107. continue
  108. truncate_tag = node_type == NodeTypeEnum.PARAMETER.value
  109. tensor_name = tensor_info.full_name
  110. # create view command
  111. ms_tensor = view_event.view_cmd.tensors.add()
  112. ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1)
  113. ms_tensor.truncate = truncate_tag
  114. ms_tensor.iter = tensor_info.iter
  115. return view_event
  116. def is_scope_type(node_type):
  117. """Judge whether the type is scope type."""
  118. return node_type.endswith('scope')
  119. def is_cst_type(node_type):
  120. """Judge whether the type is const type."""
  121. return node_type == NodeTypeEnum.CONST.value