You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

event_handler.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the message handler."""
  16. import uuid
  17. from queue import Queue, Empty
  18. from threading import Lock
  19. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  20. from mindinsight.debugger.common.log import LOGGER as log
  21. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  22. class EventHandler(StreamHandlerBase):
  23. """Message Handler."""
  24. max_limit = 1000 # the max number of items in cache
  25. def __init__(self):
  26. self._prev_flag = str(uuid.uuid4())
  27. self._cur_flag = str(uuid.uuid4())
  28. self._next_idx = 0
  29. self._event_cache = [None] * self.max_limit
  30. self._pending_requests = {}
  31. self._lock = Lock()
  32. @property
  33. def next_pos(self):
  34. """The next pos to be updated in cache."""
  35. return ':'.join([self._cur_flag, str(self._next_idx)])
  36. def has_pos(self, pos):
  37. """Get the event according to pos."""
  38. cur_flag, cur_idx = self._parse_pos(pos)
  39. event = self._event_cache[cur_idx]
  40. if event is not None:
  41. if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \
  42. (cur_flag == self._prev_flag and cur_idx >= self._next_idx):
  43. return event
  44. return None
  45. def clean(self):
  46. """Clean event cache."""
  47. with self._lock:
  48. self._prev_flag = str(uuid.uuid4())
  49. self._cur_flag = str(uuid.uuid4())
  50. self._next_idx = 0
  51. self._event_cache = [None] * self.max_limit
  52. value = {'metadata': {'pos': '0'}}
  53. self.clean_pending_requests(value)
  54. log.debug("Clean event cache. %d request is waiting.", len(self._pending_requests))
  55. def put(self, value):
  56. """
  57. Put value into event_cache.
  58. Args:
  59. value (dict): The event to be put into cache.
  60. """
  61. if not isinstance(value, dict):
  62. log.error("Dict type required when put event message.")
  63. raise DebuggerParamValueError("Dict type required when put event message.")
  64. with self._lock:
  65. log.debug("Put the %d-th message into queue. \n %d requests is waiting.",
  66. self._next_idx, len(self._pending_requests))
  67. cur_pos = self._next_idx
  68. # update next pos
  69. self._next_idx += 1
  70. if self._next_idx >= self.max_limit:
  71. self._next_idx = 0
  72. self._prev_flag = self._cur_flag
  73. self._cur_flag = str(uuid.uuid4())
  74. # set next pos
  75. if not value.get('metadata'):
  76. value['metadata'] = {}
  77. value['metadata']['pos'] = self.next_pos
  78. self._event_cache[cur_pos] = value
  79. # feed the value for pending requests
  80. self.clean_pending_requests(value)
  81. def clean_pending_requests(self, value):
  82. """Clean pending requests."""
  83. for _, request in self._pending_requests.items():
  84. request.put(value)
  85. self._pending_requests = {}
  86. def get(self, filter_condition=None):
  87. """
  88. Get the pos-th value from event_cache according to filter_condition.
  89. Args:
  90. filter_condition (str): The index of event in cache. Default: None.
  91. Returns:
  92. object, the pos-th event.
  93. """
  94. flag, idx = self._parse_pos(filter_condition)
  95. cur_id = str(uuid.uuid4())
  96. with self._lock:
  97. # reset the pos after the cache is re-initialized.
  98. if not flag or flag not in [self._cur_flag, self._prev_flag]:
  99. idx = 0
  100. # get event from cache immediately
  101. if idx != self._next_idx and self._event_cache[idx]:
  102. return self._event_cache[idx]
  103. # wait for the event
  104. cur_queue = Queue(maxsize=1)
  105. self._pending_requests[cur_id] = cur_queue
  106. # block until event has been received
  107. event = self._wait_for_event(cur_id, cur_queue, filter_condition)
  108. return event
  109. def _parse_pos(self, pos):
  110. """Get next pos according to input position."""
  111. elements = pos.split(':')
  112. try:
  113. idx = int(elements[-1])
  114. except ValueError:
  115. log.error("Invalid index. The index in pos should be digit but get pos:%s", pos)
  116. raise DebuggerParamValueError("Invalid pos.")
  117. if idx < 0 or idx >= self.max_limit:
  118. log.error("Invalid index. The index in pos should between [0, %d)", self.max_limit)
  119. raise DebuggerParamValueError(f"Invalid pos. {idx}")
  120. flag = elements[0] if len(elements) == 2 else ''
  121. return flag, idx
  122. def _wait_for_event(self, cur_id, cur_queue, pos):
  123. """Wait for the pos-th event."""
  124. try:
  125. # set the timeout to 25 seconds which is less the the timeout limit from UI
  126. event = cur_queue.get(timeout=25)
  127. except Empty:
  128. event = None
  129. if event is None:
  130. with self._lock:
  131. if self._pending_requests.get(cur_id):
  132. self._pending_requests.pop(cur_id)
  133. log.debug("Clean timeout request. Left pending requests: %d",
  134. len(self._pending_requests))
  135. event = {'metadata': {'pos': pos}}
  136. return event