You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the utils."""
  16. import enum
  17. from collections import namedtuple
  18. import numpy as np
  19. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  20. from mindinsight.debugger.common.log import logger as log
  21. from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
  22. from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum
  23. # translate the MindSpore type to numpy type.
  24. NUMPY_TYPE_MAP = {
  25. 'DT_BOOL': np.bool,
  26. 'DT_INT8': np.int8,
  27. 'DT_INT16': np.int16,
  28. 'DT_INT32': np.int32,
  29. 'DT_INT64': np.int64,
  30. 'DT_UINT8': np.uint8,
  31. 'DT_UINT16': np.uint16,
  32. 'DT_UINT32': np.uint32,
  33. 'DT_UINT64': np.uint64,
  34. 'DT_FLOAT16': np.float16,
  35. 'DT_FLOAT32': np.float32,
  36. 'DT_FLOAT64': np.float64,
  37. 'DT_STRING': np.str
  38. }
  39. @enum.unique
  40. class ReplyStates(enum.Enum):
  41. """Define the status of reply."""
  42. SUCCESS = 0
  43. FAILED = -1
  44. @enum.unique
  45. class ServerStatus(enum.Enum):
  46. """The status of debugger server."""
  47. PENDING = 'pending' # no client session has been connected
  48. RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph
  49. WAITING = 'waiting' # the client session is ready
  50. RUNNING = 'running' # the client session is running a script
  51. @enum.unique
  52. class Streams(enum.Enum):
  53. """Define the enable streams to be deal with."""
  54. COMMAND = "command"
  55. DATA = "data"
  56. METADATA = "metadata"
  57. GRAPH = 'node'
  58. TENSOR = 'tensor'
  59. WATCHPOINT = 'watchpoint'
  60. WATCHPOINT_HIT = 'watchpoint_hit'
  61. NodeBasicInfo = namedtuple('node_basic_info', ['name', 'full_name', 'type'])
  62. def get_ack_reply(state=0):
  63. """The the ack EventReply."""
  64. reply = EventReply()
  65. state_mapping = {
  66. 0: EventReply.Status.OK,
  67. 1: EventReply.Status.FAILED,
  68. 2: EventReply.Status.PENDING
  69. }
  70. reply.status = state_mapping[state]
  71. return reply
  72. def wrap_reply_response(error_code=None, error_message=None):
  73. """
  74. Wrap reply response.
  75. Args:
  76. error_code (str): Error code. Default: None.
  77. error_message (str): Error message. Default: None.
  78. Returns:
  79. str, serialized response.
  80. """
  81. if error_code is None:
  82. reply = {'state': ReplyStates.SUCCESS.value}
  83. else:
  84. reply = {
  85. 'state': ReplyStates.FAILED.value,
  86. 'error_code': error_code,
  87. 'error_message': error_message
  88. }
  89. return reply
  90. def create_view_event_from_tensor_history(tensor_history):
  91. """
  92. Create view event reply according to tensor names.
  93. Args:
  94. tensor_history (list[dict]): The list of tensor history. Each element has keys:
  95. `name`, `node_type`.
  96. Returns:
  97. EventReply, the event reply with view cmd.
  98. """
  99. view_event = get_ack_reply()
  100. for tensor_info in tensor_history:
  101. node_type = tensor_info.get('node_type')
  102. if node_type == NodeTypeEnum.CONST.value:
  103. continue
  104. truncate_tag = tensor_info.get('node_type') == NodeTypeEnum.PARAMETER.value
  105. tensor_name = tensor_info.get('full_name', '')
  106. # create view command
  107. ms_tensor = view_event.view_cmd.tensors.add()
  108. ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1)
  109. ms_tensor.truncate = truncate_tag
  110. ms_tensor.iter = 'prev' if tensor_info.get('iter') else ''
  111. return view_event
  112. def is_scope_type(node_type):
  113. """Judge whether the type is scope type."""
  114. scope_types = [NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value]
  115. return node_type in scope_types
  116. def str_to_slice_or_int(input_str):
  117. """
  118. Translate param from string to slice or int.
  119. Args:
  120. input_str (str): The string to be translated.
  121. Returns:
  122. Union[int, slice], the transformed param.
  123. """
  124. try:
  125. if ':' in input_str:
  126. ret = slice(*map(lambda x: int(x.strip()) if x.strip() else None, input_str.split(':')))
  127. else:
  128. ret = int(input_str)
  129. except ValueError as err:
  130. log.error("Failed to create slice from %s", input_str)
  131. log.exception(err)
  132. raise DebuggerParamValueError("Invalid shape.")
  133. return ret