You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the utils."""
  16. import enum
  17. import numpy as np
  18. from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
  19. from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
  20. # translate the MindSpore type to numpy type.
  21. NUMPY_TYPE_MAP = {
  22. 'DT_BOOL': np.bool,
  23. 'DT_INT8': np.int8,
  24. 'DT_INT16': np.int16,
  25. 'DT_INT32': np.int32,
  26. 'DT_INT64': np.int64,
  27. 'DT_UINT8': np.uint8,
  28. 'DT_UINT16': np.uint16,
  29. 'DT_UINT32': np.uint32,
  30. 'DT_UINT64': np.uint64,
  31. 'DT_FLOAT16': np.float16,
  32. 'DT_FLOAT32': np.float32,
  33. 'DT_FLOAT64': np.float64,
  34. 'DT_STRING': np.str
  35. }
  36. @enum.unique
  37. class ReplyStates(enum.Enum):
  38. """Define the status of reply."""
  39. SUCCESS = 0
  40. FAILED = -1
  41. @enum.unique
  42. class ServerStatus(enum.Enum):
  43. """The status of debugger server."""
  44. PENDING = 'pending' # no client session has been connected
  45. RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph
  46. WAITING = 'waiting' # the client session is ready
  47. RUNNING = 'running' # the client session is running a script
  48. @enum.unique
  49. class Streams(enum.Enum):
  50. """Define the enable streams to be deal with."""
  51. COMMAND = "command"
  52. DATA = "data"
  53. METADATA = "metadata"
  54. GRAPH = 'node'
  55. TENSOR = 'tensor'
  56. WATCHPOINT = 'watchpoint'
  57. WATCHPOINT_HIT = 'watchpoint_hit'
  58. class RunLevel(enum.Enum):
  59. """Run Level enum, it depends on whether the program is executed node by node,
  60. step by step, or in recheck phase"""
  61. NODE = "node"
  62. STEP = "step"
  63. RECHECK = "recheck"
  64. def get_ack_reply(state=0):
  65. """The the ack EventReply."""
  66. reply = EventReply()
  67. state_mapping = {
  68. 0: EventReply.Status.OK,
  69. 1: EventReply.Status.FAILED,
  70. 2: EventReply.Status.PENDING
  71. }
  72. reply.status = state_mapping[state]
  73. return reply
  74. def wrap_reply_response(error_code=None, error_message=None):
  75. """
  76. Wrap reply response.
  77. Args:
  78. error_code (str): Error code. Default: None.
  79. error_message (str): Error message. Default: None.
  80. Returns:
  81. str, serialized response.
  82. """
  83. if error_code is None:
  84. reply = {'state': ReplyStates.SUCCESS.value}
  85. else:
  86. reply = {
  87. 'state': ReplyStates.FAILED.value,
  88. 'error_code': error_code,
  89. 'error_message': error_message
  90. }
  91. return reply
  92. def create_view_event_from_tensor_history(tensor_history):
  93. """
  94. Create view event reply according to tensor names.
  95. Args:
  96. tensor_history (list[dict]): The list of tensor history. Each element has keys:
  97. `name`, `node_type`.
  98. Returns:
  99. EventReply, the event reply with view cmd.
  100. """
  101. view_event = get_ack_reply()
  102. for tensor_info in tensor_history:
  103. node_type = tensor_info.get('node_type')
  104. if node_type == NodeTypeEnum.CONST.value:
  105. continue
  106. truncate_tag = tensor_info.get('node_type') == NodeTypeEnum.PARAMETER.value
  107. tensor_name = tensor_info.get('full_name', '')
  108. # create view command
  109. ms_tensor = view_event.view_cmd.tensors.add()
  110. ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1)
  111. ms_tensor.truncate = truncate_tag
  112. ms_tensor.iter = 'prev' if tensor_info.get('iter') else ''
  113. return view_event
  114. def is_scope_type(node_type):
  115. """Judge whether the type is scope type."""
  116. return node_type.endswith('scope')