You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_cache.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Implement the debugger data cache manager."""
  16. import sys
  17. from mindinsight.debugger.common.log import LOGGER as log
  18. from mindinsight.debugger.common.utils import Streams
  19. from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \
  20. TensorHandler, WatchpointHandler, WatchpointHitHandler
  21. STREAM_HANDLER_MAP = {
  22. Streams.COMMAND.value: EventHandler,
  23. Streams.DATA.value: EventHandler,
  24. Streams.METADATA.value: MetadataHandler,
  25. Streams.GRAPH.value: GraphHandler,
  26. Streams.TENSOR.value: TensorHandler,
  27. Streams.WATCHPOINT.value: WatchpointHandler,
  28. Streams.WATCHPOINT_HIT.value: WatchpointHitHandler
  29. }
  30. class DebuggerCache:
  31. """The debugger data cache manager."""
  32. def __init__(self):
  33. self._stream_handler = {}
  34. def initialize(self):
  35. """Initialize the stream handlers."""
  36. self._stream_handler = {}
  37. for stream in Streams:
  38. mode = stream.value
  39. stream_handler = STREAM_HANDLER_MAP.get(mode)
  40. self._stream_handler[mode] = stream_handler()
  41. def clean(self):
  42. """Clean cache for all stream."""
  43. for _, stream_handler in self._stream_handler.items():
  44. stream_handler.clean()
  45. def get_stream_handler(self, mode):
  46. """
  47. Get the stream handler object.
  48. Args:
  49. mode (Streams): The type of stream handler.
  50. Returns:
  51. StreamHandlerBase, the stream handler object.
  52. """
  53. return self._stream_handler.get(mode.value)
  54. def _get(self, mode, pos):
  55. """
  56. Get updated data or command from cache.
  57. Args:
  58. mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
  59. pos (int): The index of info.
  60. Returns:
  61. object, the pos-th message about `mode` type of info.
  62. """
  63. stream_handler = self.get_stream_handler(mode)
  64. return stream_handler.get(pos)
  65. def _put(self, mode, value):
  66. """
  67. Set updated data or command from cache.
  68. Args:
  69. mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
  70. value (object): The info to be record in cache.
  71. """
  72. stream_handler = self.get_stream_handler(mode)
  73. return stream_handler.put(value)
  74. def get_command(self, pos):
  75. """
  76. Get the pos-th command in command stream.
  77. Args:
  78. pos (int): The index of command.
  79. Returns:
  80. int, the position of next message.
  81. EventReply, the command object.
  82. """
  83. content = self._get(Streams.COMMAND, pos)
  84. next_pos = content.get('metadata').get('pos')
  85. reply = content.get('cmd')
  86. return next_pos, reply
  87. def put_command(self, cmd):
  88. """
  89. Set command to command stream.
  90. Args:
  91. cmd (EventReply): The command EventReply.
  92. """
  93. log.debug("Set command %s", cmd)
  94. return self._put(Streams.COMMAND, {'cmd': cmd})
  95. def has_command(self, pos):
  96. """Judge if the number of command is no less than `pos`."""
  97. event = self.get_stream_handler(Streams.COMMAND).has_pos(pos)
  98. return event
  99. def clean_command(self):
  100. """Clean command queue."""
  101. self.get_stream_handler(Streams.COMMAND).clean()
  102. log.debug("Clean command.")
  103. def clean_data(self):
  104. """Clean command queue."""
  105. self.get_stream_handler(Streams.DATA).clean()
  106. log.debug("Clean data queue.")
  107. def get_data(self, pos):
  108. """
  109. Get updated data from data stream.
  110. Args:
  111. pos (int): The index of data.
  112. Returns:
  113. object, updated data_value.
  114. """
  115. return self._get(Streams.DATA, pos)
  116. def put_data(self, value):
  117. """
  118. Set updated data to data stream.
  119. Args:
  120. value (dict): The updated data.
  121. """
  122. log.debug("Set <%d> bytes data", sys.getsizeof(value))
  123. return self._put(Streams.DATA, value)